Improve internal logging functionality by using Python's standard logging module (#1956)

* Initial plan for issue

* Implement standard logging module and integrate with existing loggers

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Add test cases and improve documentation for standard logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Apply ruff formatting and add semversioner file for logging improvements

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove custom logger classes and refactor to use standard logging only

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Apply ruff formatting to resolve CI/CD test failures

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Add semversioner file and fix linting issues

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* ruff fixes

* fix spelling error

* Remove StandardProgressLogger and refactor to use standard logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove LoggerFactory and custom loggers, refactor to use standard logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix pyright error: use logger.info() instead of calling logger as function in cosmosdb_pipeline_storage.py

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* ruff fixes

* Remove deprecated logger files that were marked as deprecated placeholders

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Replace custom get_logger with standard Python logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix linting issues found by ruff check --fix

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff check fixes

* add word to dictionary

* Fix type checker error in ModelManager.__new__ method

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Refactor multiple logging.getLogger() calls to use single logger per file

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
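
For reference, the idiom this converges on is the standard one: a single module-level logger named after the module, with no handlers attached in library code (the function shown is illustrative):

import logging

logger = logging.getLogger(__name__)  # one logger per module, named after it


def load_documents(path: str) -> None:
    # library code logs through the module logger and never configures handlers
    logger.info("loading documents from %s", path)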

* Remove progress_logger parameter from build_index() and logger parameter from generate_indexing_prompts()

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove logger parameter from run_pipeline and standardize logger naming

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Replace logger parameter with log_level parameter in CLI commands

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix import ordering in notebook files to pass poetry poe check

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove --logger parameter from smoke test command

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix Windows CI/CD issue with log file cleanup in tests

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Add StreamHandler to root logger in __main__.py for CLI logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Only add StreamHandler if root logger doesn't have existing StreamHandler

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
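
A minimal sketch of that guard, assuming the goal is to avoid stacking duplicate console handlers when the CLI entry point runs more than once (the helper name is hypothetical):

import logging
import sys


def _ensure_console_handler() -> None:
    root = logging.getLogger()
    # FileHandler subclasses StreamHandler, so it must be excluded explicitly
    has_console = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in root.handlers
    )
    if not has_console:
        root.addHandler(logging.StreamHandler(sys.stderr))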

* Fix import ordering in notebook files to pass ruff checks

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Replace logging.StreamHandler with colorlog.StreamHandler for colorized log output

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
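
colorlog wraps the standard handler with level-based colors; the swap is roughly the following (the format string is an assumption). Note that a later entry in this list ("remove colorlog for now") backs this out again:

import logging

import colorlog

handler = colorlog.StreamHandler()
handler.setFormatter(
    colorlog.ColoredFormatter("%(log_color)s%(levelname)-8s%(reset)s %(name)s: %(message)s")
)
logging.getLogger().addHandler(handler)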

* Regenerate poetry.lock file after adding colorlog dependency

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix import ordering in notebook files to pass ruff checks

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* move printing of dataframes to debug level

* remove colorlog for now

* Refactor workflow callbacks to inherit from logging.Handler

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
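
The shape of that refactor, sketched under the assumption that workflow callbacks now receive records through logging.Handler.emit rather than bespoke error/warning/log methods (class and sink names are hypothetical):

import logging


class WorkflowLogHandler(logging.Handler):
    """Forward pipeline log records to a workflow-specific sink."""

    def __init__(self, sink) -> None:
        super().__init__()
        self._sink = sink

    def emit(self, record: logging.LogRecord) -> None:
        # invoked by the logging framework for every record at or above the level
        self._sink.write(self.format(record) + "\n")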

* Fix linting issues in workflow callback handlers

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix pyright type errors in blob and file workflow callbacks

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Refactor pipeline logging to use pure logging.Handler subclasses

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Rename workflow callback classes to workflow logger classes and move to logger directory

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* update dictionary

* apply ruff fixes

* fix function name

* simplify logger code

* update

* Remove error, warning, and log methods from WorkflowCallbacks and replace with standard logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* ruff fixes

* Fix pyright errors by removing WorkflowCallbacks from strategy type signatures

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove ConsoleWorkflowLogger and apply consistent formatter to all handlers

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff fixes

* Refactor pipeline_logger.py to use standard FileHandler and remove FileWorkflowLogger

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove conditional azure import checks from blob_workflow_logger.py

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix pyright type checking errors in mock_provider.py and utils.py

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Run ruff check --fix to fix import ordering in notebooks

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Merge configure_logging and create_pipeline_logger into init_loggers function

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove configure_logging and create_pipeline_logger functions, replace all usage with init_loggers

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff fixes

* cleanup unused code

* Update init_loggers to accept GraphRagConfig instead of ReportingConfig

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff check fixes

* Fix test failures by providing valid GraphRagConfig with required model configurations

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff fixes

* remove logging_workflow_callback

* cleanup logging messages

* Add logging to track progress of pandas DataFrame apply operation in create_base_text_units

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
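
One plausible shape for that progress logging: thread a row counter through the callable passed to DataFrame.apply and emit a log line every N rows (all names below are illustrative, not the actual create_base_text_units internals):

import logging

import pandas as pd

logger = logging.getLogger(__name__)


def apply_with_progress(df: pd.DataFrame, every: int = 1000) -> pd.Series:
    done = 0

    def _work(row: pd.Series) -> str:
        nonlocal done
        done += 1
        if done % every == 0:
            logger.info("processed %d of %d rows", done, len(df))
        return row["text"][:100]  # stand-in for the real per-row work

    return df.apply(_work, axis=1)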

* cleanup logger logic throughout codebase

* update

* more cleanup of old loggers

* small logger cleanup

* final code cleanup and added loggers to query

* add verbose logging to query

* minor code cleanup

* Fix broken unit tests for chunk_text and standard_logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff fixes

* Fix test_chunk_text by mocking progress_ticker function instead of ProgressTicker class

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* remove unnecessary logger

* remove rich and fix type annotation

* revert test formatting changes made by copilot

* promote graphrag logs to root logger
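
In standard-logging terms, promoting graphrag logs to the root logger most likely means relying on propagation: handlers are attached once at the root, and package loggers keep the default propagate=True (a sketch):

import logging

# handlers live on the root logger only...
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler())

# ...and records from any graphrag.* logger propagate up to them
logging.getLogger("graphrag.index.run").info("reaches the root handler")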

* add correct semversioner file

* revert change to file

* revert formatting changes that have no effect

* fix changes after merge with main

* revert unnecessary copilot changes

* remove whitespace

* cleanup docstring

* simplify some logic with less code

* update poetry lock file

* ruff fixes

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
Copilot 2025-07-09 18:29:03 -06:00 committed by GitHub
parent 27c6de846f
commit e84df28e64
128 changed files with 2125 additions and 2015 deletions


@ -0,0 +1,4 @@
{
"type": "patch",
"description": "cleaned up logging to follow python standards."
}


@ -102,6 +102,7 @@ itertuples
isin
nocache
nbconvert
levelno
# HTML
nbsp
@ -186,6 +187,7 @@ Verdantis's
# English
skippable
upvote
unconfigured
# Misc
Arxiv


@ -178,11 +178,11 @@ This section controls the cache mechanism used by the pipeline. This is used to
### Reporting
This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to the console or to an Azure Blob Storage container.
This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to an Azure Blob Storage container.
| Parameter | Description | Type | Required or Optional | Default |
| --------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | -------------------- | ------- |
| `GRAPHRAG_REPORTING_TYPE` | The type of reporter to use. Options are `file`, `console`, or `blob` | `str` | optional | `file` |
| `GRAPHRAG_REPORTING_TYPE` | The type of reporter to use. Options are `file` or `blob` | `str` | optional | `file` |
| `GRAPHRAG_REPORTING_STORAGE_ACCOUNT_BLOB_URL` | The Azure Storage blob endpoint to use when in `blob` mode and using managed identity. Will have the format `https://<storage_account_name>.blob.core.windows.net` | `str` | optional | None |
| `GRAPHRAG_REPORTING_CONNECTION_STRING` | The Azure Storage connection string to use when in `blob` mode. | `str` | optional | None |
| `GRAPHRAG_REPORTING_CONTAINER_NAME` | The Azure Storage container name to use when in `blob` mode. | `str` | optional | None |


@ -149,11 +149,11 @@ This section controls the cache mechanism used by the pipeline. This is used to
### reporting
This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to the console or to an Azure Blob Storage container.
This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to an Azure Blob Storage container.
#### Fields
- `type` **file|console|blob** - The reporting type to use. Default=`file`
- `type` **file|blob** - The reporting type to use. Default=`file`
- `base_dir` **str** - The base directory to write reports to, relative to the root.
- `connection_string` **str** - (blob only) The Azure Storage connection string.
- `container_name` **str** - (blob only) The Azure Storage container name.


@ -2,3 +2,10 @@
# Licensed under the MIT License
"""The GraphRAG package."""
import logging
from graphrag.logger.standard_logging import init_console_logger
logger = logging.getLogger(__name__)
init_console_logger()
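
init_console_logger itself is defined in graphrag/logger/standard_logging.py and is not shown in this hunk; a hedged sketch of what calling it at import time plausibly does (details are assumptions, not the actual implementation):

import logging
import sys


def init_console_logger(level: int = logging.INFO) -> None:
    # assumption: attach one console handler to the package logger by default
    pkg = logging.getLogger("graphrag")
    if not pkg.handlers:
        handler = logging.StreamHandler(sys.stderr)
        handler.setFormatter(logging.Formatter("%(levelname)s %(name)s: %(message)s"))
        pkg.addHandler(handler)
    pkg.setLevel(level)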


@ -10,7 +10,7 @@ Backwards compatibility is not guaranteed at this time.
import logging
from graphrag.callbacks.reporting import create_pipeline_reporter
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import IndexingMethod
from graphrag.config.models.graph_rag_config import GraphRagConfig
@ -19,10 +19,9 @@ from graphrag.index.run.utils import create_callback_chain
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.index.typing.workflow import WorkflowFunction
from graphrag.index.workflows.factory import PipelineFactory
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.logger.standard_logging import init_loggers
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def build_index(
@ -31,7 +30,6 @@ async def build_index(
is_update_run: bool = False,
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
progress_logger: ProgressLogger | None = None,
) -> list[PipelineRunResult]:
"""Run the pipeline with the given configuration.
@ -45,26 +43,25 @@ async def build_index(
Whether to enable memory profiling.
callbacks : list[WorkflowCallbacks] | None default=None
A list of callbacks to register.
progress_logger : ProgressLogger | None default=None
The progress logger.
Returns
-------
list[PipelineRunResult]
The list of pipeline run results
"""
logger = progress_logger or NullProgressLogger()
# create a pipeline reporter and add to any additional callbacks
callbacks = callbacks or []
callbacks.append(create_pipeline_reporter(config.reporting, None))
init_loggers(config=config)
workflow_callbacks = create_callback_chain(callbacks, logger)
# Create callbacks for pipeline lifecycle events if provided
workflow_callbacks = (
create_callback_chain(callbacks) if callbacks else NoopWorkflowCallbacks()
)
outputs: list[PipelineRunResult] = []
if memory_profile:
log.warning("New pipeline does not yet support memory profiling.")
logger.warning("New pipeline does not yet support memory profiling.")
logger.info("Initializing indexing pipeline...")
# todo: this could propagate out to the cli for better clarity, but will be a breaking api change
method = _get_method(method, is_update_run)
pipeline = PipelineFactory.create_pipeline(config, method)
@ -75,15 +72,14 @@ async def build_index(
pipeline,
config,
callbacks=workflow_callbacks,
logger=logger,
is_update_run=is_update_run,
):
outputs.append(output)
if output.errors and len(output.errors) > 0:
logger.error(output.workflow)
logger.error("Workflow %s completed with errors", output.workflow)
else:
logger.success(output.workflow)
logger.info(str(output.result))
logger.info("Workflow %s completed successfully", output.workflow)
logger.debug(str(output.result))
workflow_callbacks.pipeline_end(outputs)
return outputs
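
After this change callers no longer pass a progress_logger; build_index configures logging from the config itself via init_loggers. A usage sketch (config loading elided; the loader call is hypothetical):

import asyncio

import graphrag.api as api


async def main(config) -> None:
    # config is a GraphRagConfig; init_loggers(config=config) runs inside build_index
    results = await api.build_index(config=config)
    for result in results:
        print(result.workflow, "errors" if result.errors else "ok")

# asyncio.run(main(load_config(...)))  # load_config call shown for illustration only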


@ -11,6 +11,7 @@ WARNING: This API is under development and may undergo changes in future release
Backwards compatibility is not guaranteed at this time.
"""
import logging
from typing import Annotated
import annotated_types
@ -20,7 +21,7 @@ from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager
from graphrag.logger.base import ProgressLogger
from graphrag.logger.standard_logging import init_loggers
from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT, PROMPT_TUNING_MODEL_ID
from graphrag.prompt_tune.generator.community_report_rating import (
generate_community_report_rating,
@ -47,11 +48,12 @@ from graphrag.prompt_tune.generator.persona import generate_persona
from graphrag.prompt_tune.loader.input import load_docs_in_chunks
from graphrag.prompt_tune.types import DocSelectionType
logger = logging.getLogger(__name__)
@validate_call(config={"arbitrary_types_allowed": True})
async def generate_indexing_prompts(
config: GraphRagConfig,
logger: ProgressLogger,
chunk_size: PositiveInt = graphrag_config_defaults.chunks.size,
overlap: Annotated[
int, annotated_types.Gt(-1)
@ -71,8 +73,6 @@ async def generate_indexing_prompts(
Parameters
----------
- config: The GraphRag configuration.
- logger: The logger to use for progress updates.
- root: The root directory.
- output_path: The path to store the prompts.
- chunk_size: The chunk token size to use for input text units.
- limit: The limit of chunks to load.
@ -89,6 +89,8 @@ async def generate_indexing_prompts(
-------
tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
"""
init_loggers(config=config)
# Retrieve documents
logger.info("Chunking documents...")
doc_list = await load_docs_in_chunks(
@ -187,9 +189,9 @@ async def generate_indexing_prompts(
language=language,
)
logger.info(f"\nGenerated domain: {domain}") # noqa: G004
logger.info(f"\nDetected language: {language}") # noqa: G004
logger.info(f"\nGenerated persona: {persona}") # noqa: G004
logger.debug("Generated domain: %s", domain)
logger.debug("Detected language: %s", language)
logger.debug("Generated persona: %s", persona)
return (
extract_graph_prompt,


@ -17,6 +17,7 @@ WARNING: This API is under development and may undergo changes in future release
Backwards compatibility is not guaranteed at this time.
"""
import logging
from collections.abc import AsyncGenerator
from typing import Any
@ -31,7 +32,7 @@ from graphrag.config.embeddings import (
text_unit_text_embedding,
)
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.logger.standard_logging import init_loggers
from graphrag.query.factory import (
get_basic_search_engine,
get_drift_search_engine,
@ -50,11 +51,13 @@ from graphrag.query.indexer_adapters import (
from graphrag.utils.api import (
get_embedding_store,
load_search_prompt,
truncate,
update_context_data,
)
from graphrag.utils.cli import redact
logger = PrintProgressLogger("")
# Initialize standard logger
logger = logging.getLogger(__name__)
@validate_call(config={"arbitrary_types_allowed": True})
@ -68,6 +71,7 @@ async def global_search(
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@ -88,11 +92,9 @@ async def global_search(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
callbacks = callbacks or []
full_response = ""
context_data = {}
@ -105,6 +107,7 @@ async def global_search(
local_callbacks.on_context = on_context
callbacks.append(local_callbacks)
logger.debug("Executing global search query: %s", query)
async for chunk in global_search_streaming(
config=config,
entities=entities,
@ -117,6 +120,7 @@ async def global_search(
callbacks=callbacks,
):
full_response += chunk
logger.debug("Query response: %s", truncate(full_response, 400))
return full_response, context_data
@ -131,6 +135,7 @@ def global_search_streaming(
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> AsyncGenerator:
"""Perform a global search and return the context data and response via a generator.
@ -150,11 +155,9 @@ def global_search_streaming(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
communities_ = read_indexer_communities(communities, community_reports)
reports = read_indexer_reports(
community_reports,
@ -173,6 +176,7 @@ def global_search_streaming(
config.root_dir, config.global_search.knowledge_prompt
)
logger.debug("Executing streaming global search query: %s", query)
search_engine = get_global_search_engine(
config,
reports=reports,
@ -201,6 +205,7 @@ async def multi_index_global_search(
streaming: bool,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@ -223,11 +228,9 @@ async def multi_index_global_search(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_global_search"
@ -311,6 +314,7 @@ async def multi_index_global_search(
communities_dfs, axis=0, ignore_index=True, sort=False
)
logger.debug("Executing multi-index global search query: %s", query)
result = await global_search(
config,
entities=entities_combined,
@ -326,6 +330,7 @@ async def multi_index_global_search(
# Update the context data by linking index names and community ids
context = update_context_data(result[1], links)
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
return (result[0], context)
@ -342,6 +347,7 @@ async def local_search(
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@ -362,11 +368,9 @@ async def local_search(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
callbacks = callbacks or []
full_response = ""
context_data = {}
@ -379,6 +383,7 @@ async def local_search(
local_callbacks.on_context = on_context
callbacks.append(local_callbacks)
logger.debug("Executing local search query: %s", query)
async for chunk in local_search_streaming(
config=config,
entities=entities,
@ -393,6 +398,7 @@ async def local_search(
callbacks=callbacks,
):
full_response += chunk
logger.debug("Query response: %s", truncate(full_response, 400))
return full_response, context_data
@ -409,6 +415,7 @@ def local_search_streaming(
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> AsyncGenerator:
"""Perform a local search and return the context data and response via a generator.
@ -427,16 +434,14 @@ def local_search_streaming(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
msg = f"Vector Store Args: {redact(vector_store_args)}"
logger.info(msg)
logger.debug(msg)
description_embedding_store = get_embedding_store(
config_args=vector_store_args,
@ -447,6 +452,7 @@ def local_search_streaming(
covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
prompt = load_search_prompt(config.root_dir, config.local_search.prompt)
logger.debug("Executing streaming local search query: %s", query)
search_engine = get_local_search_engine(
config=config,
reports=read_indexer_reports(community_reports, communities, community_level),
@ -477,6 +483,7 @@ async def multi_index_local_search(
streaming: bool,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@ -500,11 +507,9 @@ async def multi_index_local_search(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_index_local_search"
@ -670,6 +675,7 @@ async def multi_index_local_search(
covariates_combined = pd.concat(
covariates_dfs, axis=0, ignore_index=True, sort=False
)
logger.debug("Executing multi-index local search query: %s", query)
result = await local_search(
config,
entities=entities_combined,
@ -687,6 +693,7 @@ async def multi_index_local_search(
# Update the context data by linking index names and community ids
context = update_context_data(result[1], links)
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
return (result[0], context)
@ -702,6 +709,7 @@ async def drift_search(
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@ -721,11 +729,9 @@ async def drift_search(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
callbacks = callbacks or []
full_response = ""
context_data = {}
@ -738,6 +744,7 @@ async def drift_search(
local_callbacks.on_context = on_context
callbacks.append(local_callbacks)
logger.debug("Executing drift search query: %s", query)
async for chunk in drift_search_streaming(
config=config,
entities=entities,
@ -751,6 +758,7 @@ async def drift_search(
callbacks=callbacks,
):
full_response += chunk
logger.debug("Query response: %s", truncate(full_response, 400))
return full_response, context_data
@ -766,6 +774,7 @@ def drift_search_streaming(
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> AsyncGenerator:
"""Perform a DRIFT search and return the context data and response.
@ -782,16 +791,14 @@ def drift_search_streaming(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
msg = f"Vector Store Args: {redact(vector_store_args)}"
logger.info(msg)
logger.debug(msg)
description_embedding_store = get_embedding_store(
config_args=vector_store_args,
@ -811,6 +818,7 @@ def drift_search_streaming(
config.root_dir, config.drift_search.reduce_prompt
)
logger.debug("Executing streaming drift search query: %s", query)
search_engine = get_drift_search_engine(
config=config,
reports=reports,
@ -840,6 +848,7 @@ async def multi_index_drift_search(
streaming: bool,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@ -862,11 +871,9 @@ async def multi_index_drift_search(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_drift_search"
@ -1009,6 +1016,7 @@ async def multi_index_drift_search(
text_units_dfs, axis=0, ignore_index=True, sort=False
)
logger.debug("Executing multi-index drift search query: %s", query)
result = await drift_search(
config,
entities=entities_combined,
@ -1029,6 +1037,8 @@ async def multi_index_drift_search(
context[key] = update_context_data(result[1][key], links)
else:
context = result[1]
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
return (result[0], context)
@ -1038,6 +1048,7 @@ async def basic_search(
text_units: pd.DataFrame,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@ -1053,11 +1064,9 @@ async def basic_search(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
callbacks = callbacks or []
full_response = ""
context_data = {}
@ -1070,6 +1079,7 @@ async def basic_search(
local_callbacks.on_context = on_context
callbacks.append(local_callbacks)
logger.debug("Executing basic search query: %s", query)
async for chunk in basic_search_streaming(
config=config,
text_units=text_units,
@ -1077,6 +1087,7 @@ async def basic_search(
callbacks=callbacks,
):
full_response += chunk
logger.debug("Query response: %s", truncate(full_response, 400))
return full_response, context_data
@ -1086,6 +1097,7 @@ def basic_search_streaming(
text_units: pd.DataFrame,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> AsyncGenerator:
"""Perform a local search and return the context data and response via a generator.
@ -1098,16 +1110,14 @@ def basic_search_streaming(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
msg = f"Vector Store Args: {redact(vector_store_args)}"
logger.info(msg)
logger.debug(msg)
description_embedding_store = get_embedding_store(
config_args=vector_store_args,
@ -1116,6 +1126,7 @@ def basic_search_streaming(
prompt = load_search_prompt(config.root_dir, config.basic_search.prompt)
logger.debug("Executing streaming basic search query: %s", query)
search_engine = get_basic_search_engine(
config=config,
text_units=read_indexer_text_units(text_units),
@ -1134,6 +1145,7 @@ async def multi_index_basic_search(
streaming: bool,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@ -1151,11 +1163,9 @@ async def multi_index_basic_search(
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
init_loggers(config=config, verbose=verbose)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_basic_search"
@ -1192,6 +1202,7 @@ async def multi_index_basic_search(
text_units_dfs, axis=0, ignore_index=True, sort=False
)
logger.debug("Executing multi-index basic search query: %s", query)
return await basic_search(
config,
text_units=text_units_combined,


@ -1,32 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A logger that emits updates from the indexing engine to the console."""
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
"""A logger that writes to a console."""
def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
):
"""Handle when an error occurs."""
print(message, str(cause), stack, details) # noqa T201
def warning(self, message: str, details: dict | None = None):
"""Handle when a warning occurs."""
_print_warning(message)
def log(self, message: str, details: dict | None = None):
"""Handle when a log message is produced."""
print(message, details) # noqa T201
def _print_warning(skk):
print("\033[93m {}\033[00m".format(skk)) # noqa T201


@ -1,78 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A logger that emits updates from the indexing engine to a local file."""
import json
import logging
from io import TextIOWrapper
from pathlib import Path
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
log = logging.getLogger(__name__)
class FileWorkflowCallbacks(NoopWorkflowCallbacks):
"""A logger that writes to a local file."""
_out_stream: TextIOWrapper
def __init__(self, directory: str):
"""Create a new file-based workflow logger."""
Path(directory).mkdir(parents=True, exist_ok=True)
self._out_stream = open( # noqa: PTH123, SIM115
Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict"
)
def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
):
"""Handle when an error occurs."""
self._out_stream.write(
json.dumps(
{
"type": "error",
"data": message,
"stack": stack,
"source": str(cause),
"details": details,
},
indent=4,
ensure_ascii=False,
)
+ "\n"
)
message = f"{message} details={details}"
log.info(message)
def warning(self, message: str, details: dict | None = None):
"""Handle when a warning occurs."""
self._out_stream.write(
json.dumps(
{"type": "warning", "data": message, "details": details},
ensure_ascii=False,
)
+ "\n"
)
_print_warning(message)
def log(self, message: str, details: dict | None = None):
"""Handle when a log message is produced."""
self._out_stream.write(
json.dumps(
{"type": "log", "data": message, "details": details}, ensure_ascii=False
)
+ "\n"
)
message = f"{message} details={details}"
log.info(message)
def _print_warning(skk):
log.warning(skk)


@ -9,13 +9,13 @@ from graphrag.logger.progress import Progress
class NoopWorkflowCallbacks(WorkflowCallbacks):
"""A no-op implementation of WorkflowCallbacks."""
"""A no-op implementation of WorkflowCallbacks that logs all events to standard logging."""
def pipeline_start(self, names: list[str]) -> None:
"""Execute this callback when a the entire pipeline starts."""
"""Execute this callback to signal when the entire pipeline starts."""
def pipeline_end(self, results: list[PipelineRunResult]) -> None:
"""Execute this callback when the entire pipeline ends."""
"""Execute this callback to signal when the entire pipeline ends."""
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
@ -25,18 +25,3 @@ class NoopWorkflowCallbacks(WorkflowCallbacks):
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
) -> None:
"""Handle when an error occurs."""
def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""
def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log message occurs."""


@ -1,42 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A workflow callback manager that emits updates."""
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.logger.base import ProgressLogger
from graphrag.logger.progress import Progress
class ProgressWorkflowCallbacks(NoopWorkflowCallbacks):
"""A callbackmanager that delegates to a ProgressLogger."""
_root_progress: ProgressLogger
_progress_stack: list[ProgressLogger]
def __init__(self, progress: ProgressLogger) -> None:
"""Create a new ProgressWorkflowCallbacks."""
self._progress = progress
self._progress_stack = [progress]
def _pop(self) -> None:
self._progress_stack.pop()
def _push(self, name: str) -> None:
self._progress_stack.append(self._latest.child(name))
@property
def _latest(self) -> ProgressLogger:
return self._progress_stack[-1]
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
self._push(name)
def workflow_end(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow ends."""
self._pop()
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
self._latest(progress)


@ -1,39 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing the pipeline reporter factory."""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
from graphrag.config.enums import ReportingType
from graphrag.config.models.reporting_config import ReportingConfig
if TYPE_CHECKING:
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
def create_pipeline_reporter(
config: ReportingConfig | None, root_dir: str | None
) -> WorkflowCallbacks:
"""Create a logger for the given pipeline config."""
config = config or ReportingConfig(base_dir="logs", type=ReportingType.file)
match config.type:
case ReportingType.file:
return FileWorkflowCallbacks(
str(Path(root_dir or "") / (config.base_dir or ""))
)
case ReportingType.console:
return ConsoleWorkflowCallbacks()
case ReportingType.blob:
return BlobWorkflowCallbacks(
config.connection_string,
config.container_name,
base_dir=config.base_dir,
storage_account_blob_url=config.storage_account_blob_url,
)


@ -35,21 +35,3 @@ class WorkflowCallbacks(Protocol):
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
...
def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
) -> None:
"""Handle when an error occurs."""
...
def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""
...
def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log message occurs."""
...


@ -50,27 +50,3 @@ class WorkflowCallbacksManager(WorkflowCallbacks):
for callback in self._callbacks:
if hasattr(callback, "progress"):
callback.progress(progress)
def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
) -> None:
"""Handle when an error occurs."""
for callback in self._callbacks:
if hasattr(callback, "error"):
callback.error(message, cause, stack, details)
def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""
for callback in self._callbacks:
if hasattr(callback, "warning"):
callback.warning(message, details)
def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log message occurs."""
for callback in self._callbacks:
if hasattr(callback, "log"):
callback.log(message, details)
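
The deleted fan-out loops duplicate what the logging framework already does: a single record is dispatched to every handler registered on the logger and, via propagation, on the root. A sketch:

import logging

root = logging.getLogger()
root.setLevel(logging.INFO)
root.addHandler(logging.FileHandler("logs.txt", encoding="utf-8"))
root.addHandler(logging.StreamHandler())

# one call replaces WorkflowCallbacksManager iterating over its callbacks
logging.getLogger("graphrag.index").warning("rate limit hit; backing off")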


@ -10,49 +10,26 @@ import warnings
from pathlib import Path
import graphrag.api as api
from graphrag.config.enums import CacheType, IndexingMethod
from graphrag.config.enums import CacheType, IndexingMethod, ReportingType
from graphrag.config.load_config import load_config
from graphrag.config.logging import enable_logging_with_config
from graphrag.index.validate_config import validate_config_names
from graphrag.logger.base import ProgressLogger
from graphrag.logger.factory import LoggerFactory, LoggerType
from graphrag.utils.cli import redact
# Ignore warnings from numba
warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*")
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def _logger(logger: ProgressLogger):
def info(msg: str, verbose: bool = False):
log.info(msg)
if verbose:
logger.info(msg)
def error(msg: str, verbose: bool = False):
log.error(msg)
if verbose:
logger.error(msg)
def success(msg: str, verbose: bool = False):
log.info(msg)
if verbose:
logger.success(msg)
return info, error, success
def _register_signal_handlers(logger: ProgressLogger):
def _register_signal_handlers():
import signal
def handle_signal(signum, _):
# Handle the signal here
logger.info(f"Received signal {signum}, exiting...") # noqa: G004
logger.dispose()
logger.debug(f"Received signal {signum}, exiting...") # noqa: G004
for task in asyncio.all_tasks():
task.cancel()
logger.info("All tasks cancelled. Exiting...")
logger.debug("All tasks cancelled. Exiting...")
# Register signal handlers for SIGINT and SIGHUP
signal.signal(signal.SIGINT, handle_signal)
@ -67,7 +44,6 @@ def index_cli(
verbose: bool,
memprofile: bool,
cache: bool,
logger: LoggerType,
config_filepath: Path | None,
dry_run: bool,
skip_validation: bool,
@ -87,7 +63,6 @@ def index_cli(
verbose=verbose,
memprofile=memprofile,
cache=cache,
logger=logger,
dry_run=dry_run,
skip_validation=skip_validation,
)
@ -99,7 +74,6 @@ def update_cli(
verbose: bool,
memprofile: bool,
cache: bool,
logger: LoggerType,
config_filepath: Path | None,
skip_validation: bool,
output_dir: Path | None,
@ -120,7 +94,6 @@ def update_cli(
verbose=verbose,
memprofile=memprofile,
cache=cache,
logger=logger,
dry_run=False,
skip_validation=skip_validation,
)
@ -133,39 +106,47 @@ def _run_index(
verbose,
memprofile,
cache,
logger,
dry_run,
skip_validation,
):
progress_logger = LoggerFactory().create_logger(logger)
info, error, success = _logger(progress_logger)
# Configure the root logger with the specified log level
from graphrag.logger.standard_logging import init_loggers
# Initialize loggers and reporting config
init_loggers(
config=config,
root_dir=str(config.root_dir) if config.root_dir else None,
verbose=verbose,
)
if not cache:
config.cache.type = CacheType.none
enabled_logging, log_path = enable_logging_with_config(config, verbose)
if enabled_logging:
info(f"Logging enabled at {log_path}", True)
# Log the configuration details
if config.reporting.type == ReportingType.file:
log_dir = Path(config.root_dir or "") / (config.reporting.base_dir or "")
log_path = log_dir / "logs.txt"
logger.info("Logging enabled at %s", log_path)
else:
info(
f"Logging not enabled for config {redact(config.model_dump())}",
True,
logger.info(
"Logging not enabled for config %s",
redact(config.model_dump()),
)
if not skip_validation:
validate_config_names(progress_logger, config)
validate_config_names(config)
info(f"Starting pipeline run. {dry_run=}", verbose)
info(
f"Using default configuration: {redact(config.model_dump())}",
verbose,
logger.info("Starting pipeline run. %s", dry_run)
logger.info(
"Using default configuration: %s",
redact(config.model_dump()),
)
if dry_run:
info("Dry run complete, exiting...", True)
logger.info("Dry run complete, exiting...", True)
sys.exit(0)
_register_signal_handlers(progress_logger)
_register_signal_handlers()
outputs = asyncio.run(
api.build_index(
@ -173,19 +154,17 @@ def _run_index(
method=method,
is_update_run=is_update_run,
memory_profile=memprofile,
progress_logger=progress_logger,
)
)
encountered_errors = any(
output.errors and len(output.errors) > 0 for output in outputs
)
progress_logger.stop()
if encountered_errors:
error(
"Errors occurred during the pipeline run, see logs for more details.", True
logger.error(
"Errors occurred during the pipeline run, see logs for more details."
)
else:
success("All workflows completed successfully.", True)
logger.info("All workflows completed successfully.")
sys.exit(1 if encountered_errors else 0)
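
init_loggers is not shown in this commit view; a sketch consistent with the CLI behavior above (file logging under <root_dir>/<reporting.base_dir>/logs.txt, DEBUG when verbose), with the caveat that the real implementation may differ:

import logging
from pathlib import Path


def init_loggers(config, root_dir: str | None = None, verbose: bool = False) -> None:
    root = logging.getLogger()
    root.setLevel(logging.DEBUG if verbose else logging.INFO)
    log_dir = Path(root_dir or config.root_dir or "") / (config.reporting.base_dir or "")
    log_dir.mkdir(parents=True, exist_ok=True)
    handler = logging.FileHandler(log_dir / "logs.txt", encoding="utf-8")
    handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
    )
    root.addHandler(handler)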


@ -3,10 +3,10 @@
"""CLI implementation of the initialization subcommand."""
import logging
from pathlib import Path
from graphrag.config.init_content import INIT_DOTENV, INIT_YAML
from graphrag.logger.factory import LoggerFactory, LoggerType
from graphrag.prompts.index.community_report import (
COMMUNITY_REPORT_PROMPT,
)
@ -31,6 +31,8 @@ from graphrag.prompts.query.global_search_reduce_system_prompt import (
from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
logger = logging.getLogger(__name__)
def initialize_project_at(path: Path, force: bool) -> None:
"""
@ -48,8 +50,7 @@ def initialize_project_at(path: Path, force: bool) -> None:
ValueError
If the project already exists and force is False.
"""
progress_logger = LoggerFactory().create_logger(LoggerType.RICH)
progress_logger.info(f"Initializing project at {path}") # noqa: G004
logger.info("Initializing project at %s", path)
root = Path(path)
if not root.exists():
root.mkdir(parents=True, exist_ok=True)


@ -12,7 +12,6 @@ import typer
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import IndexingMethod, SearchMethod
from graphrag.logger.types import LoggerType
from graphrag.prompt_tune.defaults import LIMIT, MAX_TOKEN_COUNT, N_SUBSET_MAX, K
from graphrag.prompt_tune.types import DocSelectionType
@ -157,11 +156,6 @@ def _index_cli(
"--memprofile",
help="Run the indexing pipeline with memory profiling",
),
logger: LoggerType = typer.Option(
LoggerType.RICH.value,
"--logger",
help="The progress logger to use.",
),
dry_run: bool = typer.Option(
False,
"--dry-run",
@ -201,7 +195,6 @@ def _index_cli(
verbose=verbose,
memprofile=memprofile,
cache=cache,
logger=LoggerType(logger),
config_filepath=config,
dry_run=dry_run,
skip_validation=skip_validation,
@ -250,11 +243,6 @@ def _update_cli(
"--memprofile",
help="Run the indexing pipeline with memory profiling.",
),
logger: LoggerType = typer.Option(
LoggerType.RICH.value,
"--logger",
help="The progress logger to use.",
),
cache: bool = typer.Option(
True,
"--cache/--no-cache",
@ -290,7 +278,6 @@ def _update_cli(
verbose=verbose,
memprofile=memprofile,
cache=cache,
logger=LoggerType(logger),
config_filepath=config,
skip_validation=skip_validation,
output_dir=output,
@ -327,11 +314,6 @@ def _prompt_tune_cli(
"-v",
help="Run the prompt tuning pipeline with verbose logging.",
),
logger: LoggerType = typer.Option(
LoggerType.RICH.value,
"--logger",
help="The progress logger to use.",
),
domain: str | None = typer.Option(
None,
"--domain",
@ -413,7 +395,6 @@ def _prompt_tune_cli(
config=config,
domain=domain,
verbose=verbose,
logger=logger,
selection_method=selection_method,
limit=limit,
max_tokens=max_tokens,
@ -453,6 +434,12 @@ def _query_cli(
readable=True,
autocompletion=CONFIG_AUTOCOMPLETE,
),
verbose: bool = typer.Option(
False,
"--verbose",
"-v",
help="Run the query with verbose logging.",
),
data: Path | None = typer.Option(
None,
"--data",
@ -520,6 +507,7 @@ def _query_cli(
response_type=response_type,
streaming=streaming,
query=query,
verbose=verbose,
)
case SearchMethod.GLOBAL:
run_global_search(
@ -531,6 +519,7 @@ def _query_cli(
response_type=response_type,
streaming=streaming,
query=query,
verbose=verbose,
)
case SearchMethod.DRIFT:
run_drift_search(
@ -541,6 +530,7 @@ def _query_cli(
streaming=streaming,
response_type=response_type,
query=query,
verbose=verbose,
)
case SearchMethod.BASIC:
run_basic_search(
@ -549,6 +539,7 @@ def _query_cli(
root_dir=root,
streaming=streaming,
query=query,
verbose=verbose,
)
case _:
raise ValueError(INVALID_METHOD_ERROR)


@ -3,13 +3,12 @@
"""CLI implementation of the prompt-tune subcommand."""
import logging
from pathlib import Path
import graphrag.api as api
from graphrag.cli.index import _logger
from graphrag.config.enums import ReportingType
from graphrag.config.load_config import load_config
from graphrag.config.logging import enable_logging_with_config
from graphrag.logger.factory import LoggerFactory, LoggerType
from graphrag.prompt_tune.generator.community_report_summarization import (
COMMUNITY_SUMMARIZATION_FILENAME,
)
@ -21,13 +20,14 @@ from graphrag.prompt_tune.generator.extract_graph_prompt import (
)
from graphrag.utils.cli import redact
logger = logging.getLogger(__name__)
async def prompt_tune(
root: Path,
config: Path | None,
domain: str | None,
verbose: bool,
logger: LoggerType,
selection_method: api.DocSelectionType,
limit: int,
max_tokens: int,
@ -47,8 +47,7 @@ async def prompt_tune(
- config: The configuration file.
- root: The root directory.
- domain: The domain to map the input documents to.
- verbose: Whether to enable verbose logging.
- logger: The logger to use.
- verbose: Enable verbose logging.
- selection_method: The chunk selection method.
- limit: The limit of chunks to load.
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
@ -70,23 +69,29 @@ async def prompt_tune(
if overlap != graph_config.chunks.overlap:
graph_config.chunks.overlap = overlap
progress_logger = LoggerFactory().create_logger(logger)
info, error, success = _logger(progress_logger)
# configure the root logger with the specified log level
from graphrag.logger.standard_logging import init_loggers
enabled_logging, log_path = enable_logging_with_config(
graph_config, verbose, filename="prompt-tune.log"
# initialize loggers with config
init_loggers(
config=graph_config,
root_dir=str(root_path),
verbose=verbose,
)
if enabled_logging:
info(f"Logging enabled at {log_path}", verbose)
# log the configuration details
if graph_config.reporting.type == ReportingType.file:
log_dir = Path(root_path) / (graph_config.reporting.base_dir or "")
log_path = log_dir / "logs.txt"
logger.info("Logging enabled at %s", log_path)
else:
info(
f"Logging not enabled for config {redact(graph_config.model_dump())}",
verbose,
logger.info(
"Logging not enabled for config %s",
redact(graph_config.model_dump()),
)
prompts = await api.generate_indexing_prompts(
config=graph_config,
logger=progress_logger,
chunk_size=chunk_size,
overlap=overlap,
limit=limit,
@ -102,20 +107,20 @@ async def prompt_tune(
output_path = output.resolve()
if output_path:
info(f"Writing prompts to {output_path}")
logger.info("Writing prompts to %s", output_path)
output_path.mkdir(parents=True, exist_ok=True)
extract_graph_prompt_path = output_path / EXTRACT_GRAPH_FILENAME
entity_summarization_prompt_path = output_path / ENTITY_SUMMARIZATION_FILENAME
community_summarization_prompt_path = (
output_path / COMMUNITY_SUMMARIZATION_FILENAME
)
# Write files to output path
# write files to output path
with extract_graph_prompt_path.open("wb") as file:
file.write(prompts[0].encode(encoding="utf-8", errors="strict"))
with entity_summarization_prompt_path.open("wb") as file:
file.write(prompts[1].encode(encoding="utf-8", errors="strict"))
with community_summarization_prompt_path.open("wb") as file:
file.write(prompts[2].encode(encoding="utf-8", errors="strict"))
success(f"Prompts written to {output_path}")
logger.info("Prompts written to %s", output_path)
else:
error("No output path provided. Skipping writing prompts.")
logger.error("No output path provided. Skipping writing prompts.")


@ -4,6 +4,7 @@
"""CLI implementation of the query subcommand."""
import asyncio
import logging
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any
@ -12,14 +13,14 @@ import graphrag.api as api
from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks
from graphrag.config.load_config import load_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.utils.api import create_storage_from_config
from graphrag.utils.storage import load_table_from_storage, storage_has_table
if TYPE_CHECKING:
import pandas as pd
logger = PrintProgressLogger("")
# Initialize standard logger
logger = logging.getLogger(__name__)
def run_global_search(
@ -31,6 +32,7 @@ def run_global_search(
response_type: str,
streaming: bool,
query: str,
verbose: bool,
):
"""Perform a global search with a given query.
@ -59,10 +61,10 @@ def run_global_search(
final_community_reports_list = dataframe_dict["community_reports"]
index_names = dataframe_dict["index_names"]
logger.success(
f"Running Multi-index Global Search: {dataframe_dict['index_names']}"
logger.info(
"Running multi-index global search on indexes: %s",
dataframe_dict["index_names"],
)
response, context_data = asyncio.run(
api.multi_index_global_search(
config=config,
@ -75,9 +77,12 @@ def run_global_search(
response_type=response_type,
streaming=streaming,
query=query,
verbose=verbose,
)
)
logger.success(f"Global Search Response:\n{response}")
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Query Response:\n%s", response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -110,6 +115,7 @@ def run_global_search(
response_type=response_type,
query=query,
callbacks=[callbacks],
verbose=verbose,
):
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
@ -129,9 +135,12 @@ def run_global_search(
dynamic_community_selection=dynamic_community_selection,
response_type=response_type,
query=query,
verbose=verbose,
)
)
logger.success(f"Global Search Response:\n{response}")
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Global Search Response:\n%s", response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -145,6 +154,7 @@ def run_local_search(
response_type: str,
streaming: bool,
query: str,
verbose: bool,
):
"""Perform a local search with a given query.
@ -178,8 +188,9 @@ def run_local_search(
final_relationships_list = dataframe_dict["relationships"]
index_names = dataframe_dict["index_names"]
logger.success(
f"Running Multi-index Local Search: {dataframe_dict['index_names']}"
logger.info(
"Running multi-index local search on indexes: %s",
dataframe_dict["index_names"],
)
# If any covariates tables are missing from any index, set the covariates list to None
@ -202,9 +213,12 @@ def run_local_search(
response_type=response_type,
streaming=streaming,
query=query,
verbose=verbose,
)
)
logger.success(f"Local Search Response:\n{response}")
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Local Search Response:\n%s", response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -242,6 +256,7 @@ def run_local_search(
response_type=response_type,
query=query,
callbacks=[callbacks],
verbose=verbose,
):
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
@ -263,9 +278,12 @@ def run_local_search(
community_level=community_level,
response_type=response_type,
query=query,
verbose=verbose,
)
)
logger.success(f"Local Search Response:\n{response}")
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Local Search Response:\n%s", response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -279,6 +297,7 @@ def run_drift_search(
response_type: str,
streaming: bool,
query: str,
verbose: bool,
):
"""Perform a local search with a given query.
@ -310,8 +329,9 @@ def run_drift_search(
final_relationships_list = dataframe_dict["relationships"]
index_names = dataframe_dict["index_names"]
logger.success(
f"Running Multi-index Drift Search: {dataframe_dict['index_names']}"
logger.info(
"Running multi-index drift search on indexes: %s",
dataframe_dict["index_names"],
)
response, context_data = asyncio.run(
@ -327,9 +347,12 @@ def run_drift_search(
response_type=response_type,
streaming=streaming,
query=query,
verbose=verbose,
)
)
logger.success(f"DRIFT Search Response:\n{response}")
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("DRIFT Search Response:\n%s", response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -365,6 +388,7 @@ def run_drift_search(
response_type=response_type,
query=query,
callbacks=[callbacks],
verbose=verbose,
):
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
@ -386,9 +410,12 @@ def run_drift_search(
community_level=community_level,
response_type=response_type,
query=query,
verbose=verbose,
)
)
logger.success(f"DRIFT Search Response:\n{response}")
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("DRIFT Search Response:\n%s", response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -400,6 +427,7 @@ def run_basic_search(
root_dir: Path,
streaming: bool,
query: str,
verbose: bool,
):
"""Perform a basics search with a given query.
@ -423,8 +451,9 @@ def run_basic_search(
final_text_units_list = dataframe_dict["text_units"]
index_names = dataframe_dict["index_names"]
logger.success(
f"Running Multi-index Basic Search: {dataframe_dict['index_names']}"
logger.info(
"Running multi-index basic search on indexes: %s",
dataframe_dict["index_names"],
)
response, context_data = asyncio.run(
@ -434,9 +463,12 @@ def run_basic_search(
index_names=index_names,
streaming=streaming,
query=query,
verbose=verbose,
)
)
logger.success(f"Basic Search Response:\n{response}")
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Basic Search Response:\n%s", response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -461,6 +493,8 @@ def run_basic_search(
config=config,
text_units=final_text_units,
query=query,
callbacks=[callbacks],
verbose=verbose,
):
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
@ -475,9 +509,12 @@ def run_basic_search(
config=config,
text_units=final_text_units,
query=query,
verbose=verbose,
)
)
logger.success(f"Basic Search Response:\n{response}")
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Basic Search Response:\n%s", response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
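The pattern repeated throughout this file: loguru-style logger.success(f"...") calls become stdlib logger.info("...%s", value) calls, and eager f-string interpolation gives way to lazy %-style arguments that are only formatted when a handler actually emits the record. A minimal sketch of the convention (function and variable names here are illustrative, not from this PR):

import logging

logger = logging.getLogger(__name__)

def run_search(query: str) -> str:
    # %-style arguments defer formatting until the record is emitted,
    # unlike an f-string, which is always built up front.
    logger.info("Running search on query: %s", query)
    response = perform_search(query)  # hypothetical search call
    logger.info("Search Response:\n%s", response)
    return response

def perform_search(query: str) -> str:
    return f"results for {query!r}"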

View File

@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal
from typing import ClassVar, Literal
from graphrag.config.embeddings import default_embeddings
from graphrag.config.enums import (
@ -54,7 +54,7 @@ class BasicSearchDefaults:
class CacheDefaults:
"""Default values for cache."""
type = CacheType.file
type: ClassVar[CacheType] = CacheType.file
base_dir: str = "cache"
connection_string: None = None
container_name: None = None
@ -69,7 +69,7 @@ class ChunksDefaults:
size: int = 1200
overlap: int = 100
group_by_columns: list[str] = field(default_factory=lambda: ["id"])
strategy = ChunkStrategyType.tokens
strategy: ClassVar[ChunkStrategyType] = ChunkStrategyType.tokens
encoding_model: str = "cl100k_base"
prepend_metadata: bool = False
chunk_size_includes_metadata: bool = False
@ -119,8 +119,8 @@ class DriftSearchDefaults:
local_search_temperature: float = 0
local_search_top_p: float = 1
local_search_n: int = 1
local_search_llm_max_gen_tokens = None
local_search_llm_max_gen_completion_tokens = None
local_search_llm_max_gen_tokens: int | None = None
local_search_llm_max_gen_completion_tokens: int | None = None
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
@ -183,7 +183,9 @@ class ExtractGraphDefaults:
class TextAnalyzerDefaults:
"""Default values for text analyzer."""
extractor_type = NounPhraseExtractorType.RegexEnglish
extractor_type: ClassVar[NounPhraseExtractorType] = (
NounPhraseExtractorType.RegexEnglish
)
model_name: str = "en_core_web_md"
max_word_length: int = 15
word_delimiter: str = " "
@ -257,7 +259,7 @@ class InputDefaults:
"""Default values for input."""
storage: InputStorageDefaults = field(default_factory=InputStorageDefaults)
file_type = InputFileType.text
file_type: ClassVar[InputFileType] = InputFileType.text
encoding: str = "utf-8"
file_pattern: str = ""
file_filter: None = None
@ -271,7 +273,7 @@ class LanguageModelDefaults:
"""Default values for language model."""
api_key: None = None
auth_type = AuthType.APIKey
auth_type: ClassVar[AuthType] = AuthType.APIKey
encoding_model: str = ""
max_tokens: int | None = None
temperature: float = 0
@ -338,7 +340,7 @@ class PruneGraphDefaults:
class ReportingDefaults:
"""Default values for reporting."""
type = ReportingType.file
type: ClassVar[ReportingType] = ReportingType.file
base_dir: str = "logs"
connection_string: None = None
container_name: None = None
@ -383,7 +385,7 @@ class UpdateIndexOutputDefaults(StorageDefaults):
class VectorStoreDefaults:
"""Default values for vector stores."""
type = VectorStoreType.LanceDB.value
type: ClassVar[str] = VectorStoreType.LanceDB.value
db_uri: str = str(Path(DEFAULT_OUTPUT_BASE_DIR) / "lancedb")
container_name: str = "default"
overwrite: bool = True
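The recurring change in this defaults module: attributes such as type = CacheType.file carried no annotation, so dataclasses already treated them as plain class attributes rather than fields; the ClassVar annotation makes that intent explicit to readers and type checkers. A minimal sketch, using a plain str in place of the project's enums:

from dataclasses import dataclass, fields
from typing import ClassVar

@dataclass
class CacheDefaults:
    # ClassVar marks a shared class-level constant; dataclasses skip it
    # when generating __init__ and __repr__.
    type: ClassVar[str] = "file"
    base_dir: str = "cache"  # ordinary per-instance field

print([f.name for f in fields(CacheDefaults)])  # ['base_dir'] only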

View File

@ -64,8 +64,6 @@ class ReportingType(str, Enum):
file = "file"
"""The file reporting configuration type."""
console = "console"
"""The console reporting configuration type."""
blob = "blob"
"""The blob reporting configuration type."""

View File

@ -1,61 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Logging utilities. A unified way for enabling logging."""
import logging
from pathlib import Path
from graphrag.config.enums import ReportingType
from graphrag.config.models.graph_rag_config import GraphRagConfig
def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
"""Enable logging to a file.
Parameters
----------
log_filepath : str | Path
The path to the log file.
verbose : bool, default=False
Whether to log debug messages.
"""
log_filepath = Path(log_filepath)
log_filepath.parent.mkdir(parents=True, exist_ok=True)
log_filepath.touch(exist_ok=True)
logging.basicConfig(
filename=log_filepath,
filemode="a",
format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
datefmt="%H:%M:%S",
level=logging.DEBUG if verbose else logging.INFO,
)
def enable_logging_with_config(
config: GraphRagConfig, verbose: bool = False, filename: str = "indexing-engine.log"
) -> tuple[bool, str]:
"""Enable logging to a file based on the config.
Parameters
----------
config : GraphRagConfig
The configuration.
timestamp_value : str
The timestamp value representing the directory to place the log files.
verbose : bool, default=False
Whether to log debug messages.
Returns
-------
tuple[bool, str]
A tuple of a boolean indicating if logging was enabled and the path to the log file.
(False, "") if logging was not enabled.
(True, str) if logging was enabled.
"""
if config.reporting.type == ReportingType.file:
log_path = Path(config.reporting.base_dir) / filename
enable_logging(log_path, verbose)
return (True, str(log_path))
return (False, "")

View File

@ -9,17 +9,17 @@ from pathlib import Path
from dotenv import dotenv_values
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def read_dotenv(root: str) -> None:
"""Read a .env file in the given root path."""
env_path = Path(root) / ".env"
if env_path.exists():
log.info("Loading pipeline .env file")
logger.info("Loading pipeline .env file")
env_config = dotenv_values(f"{env_path}")
for key, value in env_config.items():
if key not in os.environ:
os.environ[key] = value or ""
else:
log.info("No .env file found at %s", root)
logger.info("No .env file found at %s", root)

View File

@ -10,19 +10,17 @@ import pandas as pd
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.util import load_files, process_data_columns
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def load_csv(
config: InputConfig,
progress: ProgressLogger | None,
storage: PipelineStorage,
) -> pd.DataFrame:
"""Load csv inputs from a directory."""
log.info("Loading csv files from %s", config.storage.base_dir)
logger.info("Loading csv files from %s", config.storage.base_dir)
async def load_file(path: str, group: dict | None) -> pd.DataFrame:
if group is None:
@ -42,4 +40,4 @@ async def load_csv(
return data
return await load_files(load_file, config, storage, progress)
return await load_files(load_file, config, storage)

View File

@ -14,11 +14,9 @@ from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.csv import load_csv
from graphrag.index.input.json import load_json
from graphrag.index.input.text import load_text
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
InputFileType.text: load_text,
InputFileType.csv: load_csv,
@ -29,17 +27,14 @@ loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
async def create_input(
config: InputConfig,
storage: PipelineStorage,
progress_reporter: ProgressLogger | None = None,
) -> pd.DataFrame:
"""Instantiate input data for a pipeline."""
progress_reporter = progress_reporter or NullProgressLogger()
logger.info("loading input from root_dir=%s", config.storage.base_dir)
if config.file_type in loaders:
progress = progress_reporter.child(
f"Loading Input ({config.file_type})", transient=False
)
logger.info("Loading Input %s", config.file_type)
loader = loaders[config.file_type]
result = await loader(config, progress, storage)
result = await loader(config, storage)
# Convert metadata columns to strings and collapse them into a JSON object
if config.metadata:
if all(col in result.columns for col in config.metadata):

View File

@ -10,19 +10,17 @@ import pandas as pd
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.util import load_files, process_data_columns
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def load_json(
config: InputConfig,
progress: ProgressLogger | None,
storage: PipelineStorage,
) -> pd.DataFrame:
"""Load json inputs from a directory."""
log.info("Loading json files from %s", config.storage.base_dir)
logger.info("Loading json files from %s", config.storage.base_dir)
async def load_file(path: str, group: dict | None) -> pd.DataFrame:
if group is None:
@ -46,4 +44,4 @@ async def load_json(
return data
return await load_files(load_file, config, storage, progress)
return await load_files(load_file, config, storage)

View File

@ -11,15 +11,13 @@ import pandas as pd
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.util import load_files
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def load_text(
config: InputConfig,
progress: ProgressLogger | None,
storage: PipelineStorage,
) -> pd.DataFrame:
"""Load text inputs from a directory."""
@ -34,4 +32,4 @@ async def load_text(
new_item["creation_date"] = await storage.get_creation_date(path)
return pd.DataFrame([new_item])
return await load_files(load_file, config, storage, progress)
return await load_files(load_file, config, storage)

View File

@ -11,23 +11,20 @@ import pandas as pd
from graphrag.config.models.input_config import InputConfig
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def load_files(
loader: Any,
config: InputConfig,
storage: PipelineStorage,
progress: ProgressLogger | None,
) -> pd.DataFrame:
"""Load files from storage and apply a loader function."""
files = list(
storage.find(
re.compile(config.file_pattern),
progress=progress,
file_filter=config.file_filter,
)
)
@ -42,17 +39,17 @@ async def load_files(
try:
files_loaded.append(await loader(file, group))
except Exception as e: # noqa: BLE001 (catching Exception is fine here)
log.warning("Warning! Error loading file %s. Skipping...", file)
log.warning("Error: %s", e)
logger.warning("Warning! Error loading file %s. Skipping...", file)
logger.warning("Error: %s", e)
log.info(
logger.info(
"Found %d %s files, loading %d", len(files), config.file_type, len(files_loaded)
)
result = pd.concat(files_loaded)
total_files_log = (
f"Total number of unfiltered {config.file_type} rows: {len(result)}"
)
log.info(total_files_log)
logger.info(total_files_log)
return result
@ -66,7 +63,7 @@ def process_data_columns(
)
if config.text_column is not None and "text" not in documents.columns:
if config.text_column not in documents.columns:
log.warning(
logger.warning(
"text_column %s not found in csv file %s",
config.text_column,
path,
@ -75,7 +72,7 @@ def process_data_columns(
documents["text"] = documents.apply(lambda x: x[config.text_column], axis=1)
if config.title_column is not None:
if config.title_column not in documents.columns:
log.warning(
logger.warning(
"title_column %s not found in csv file %s",
config.title_column,
path,

View File

@ -65,6 +65,7 @@ async def _extract_nodes(
extract,
num_threads=num_threads,
async_type=AsyncType.Threaded,
progress_msg="extract noun phrases progress: ",
)
noun_node_df = text_unit_df.explode("noun_phrases")

View File

@ -8,7 +8,7 @@ from abc import ABCMeta, abstractmethod
import spacy
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class BaseNounPhraseExtractor(metaclass=ABCMeta):
@ -54,7 +54,7 @@ class BaseNounPhraseExtractor(metaclass=ABCMeta):
return spacy.load(model_name, exclude=exclude)
except OSError:
msg = f"Model `{model_name}` not found. Attempting to download..."
log.info(msg)
logger.info(msg)
from spacy.cli.download import download
download(model_name)

View File

@ -13,7 +13,7 @@ from graphrag.index.utils.stable_lcc import stable_largest_connected_component
Communities = list[tuple[int, int, int, list[str]]]
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def cluster_graph(
@ -24,7 +24,7 @@ def cluster_graph(
) -> Communities:
"""Apply a hierarchical clustering algorithm to a graph."""
if len(graph.nodes) == 0:
log.warning("Graph has no nodes")
logger.warning("Graph has no nodes")
return []
node_id_to_community_map, parent_mapping = _compute_leiden_communities(

View File

@ -17,7 +17,7 @@ from graphrag.index.operations.embed_text.strategies.typing import TextEmbedding
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
from graphrag.vector_stores.factory import VectorStoreFactory
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Per Azure OpenAI Limits
# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
@ -141,7 +141,14 @@ async def _text_embed_with_vector_store(
all_results = []
num_total_batches = (input.shape[0] + insert_batch_size - 1) // insert_batch_size
while insert_batch_size * i < input.shape[0]:
logger.info(
"uploading text embeddings batch %d/%d of size %d to vector store",
i + 1,
num_total_batches,
insert_batch_size,
)
batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
texts: list[str] = batch[embed_column].to_numpy().tolist()
titles: list[str] = batch[title].to_numpy().tolist()
@ -195,7 +202,7 @@ def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
collection_name = create_collection_name(container_name, embedding_name)
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
log.info(msg)
logger.info(msg)
return collection_name

View File

@ -21,7 +21,9 @@ async def run( # noqa RUF029 async is required for interface
) -> TextEmbeddingResult:
"""Run the Claim extraction chain."""
input = input if isinstance(input, Iterable) else [input]
ticker = progress_ticker(callbacks.progress, len(input))
ticker = progress_ticker(
callbacks.progress, len(input), description="generate embeddings progress: "
)
return TextEmbeddingResult(
embeddings=[_embed_text(cache, text, ticker) for text in input]
)
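progress_ticker now accepts a description that progress handlers can display alongside the tick count. The call pattern, sketched with placeholder work items (items and process are not from this PR; callbacks is assumed to be in scope):

from graphrag.logger.progress import progress_ticker

ticker = progress_ticker(
    callbacks.progress,
    num_total=len(items),
    description="generate embeddings progress: ",
)
for item in items:
    process(item)  # placeholder for the real embedding call
    ticker(1)      # advance the ticker by one completed item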

View File

@ -19,7 +19,7 @@ from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.logger.progress import ProgressTicker, progress_ticker
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def run(
@ -54,7 +54,7 @@ async def run(
batch_max_tokens,
splitter,
)
log.info(
logger.info(
"embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, batch_max_tokens=%d",
len(input),
len(texts),
@ -62,7 +62,11 @@ async def run(
batch_size,
batch_max_tokens,
)
ticker = progress_ticker(callbacks.progress, len(text_batches))
ticker = progress_ticker(
callbacks.progress,
len(text_batches),
description="generate embeddings progress: ",
)
# Embed each chunk of snippets
embeddings = await _execute(model, text_batches, ticker, semaphore)

View File

@ -20,7 +20,7 @@ from graphrag.prompts.index.extract_claims import (
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
@dataclass
@ -119,7 +119,7 @@ class ClaimExtractor:
]
source_doc_map[document_id] = text
except Exception as e:
log.exception("error extracting claim")
logger.exception("error extracting claim")
self._on_error(
e,
traceback.format_exc(),

View File

@ -23,7 +23,7 @@ from graphrag.index.operations.extract_covariates.typing import (
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.language_model.manager import ModelManager
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
@ -41,7 +41,7 @@ async def extract_covariates(
num_threads: int = 4,
):
"""Extract claims from a piece of text."""
log.debug("extract_covariates strategy=%s", strategy)
logger.debug("extract_covariates strategy=%s", strategy)
if entity_types is None:
entity_types = DEFAULT_ENTITY_TYPES
@ -71,6 +71,7 @@ async def extract_covariates(
callbacks,
async_type=async_mode,
num_threads=num_threads,
progress_msg="extract covariates progress: ",
)
return pd.DataFrame([item for row in results for item in row or []])
@ -110,8 +111,8 @@ async def run_extract_claims(
model_invoker=llm,
extraction_prompt=extraction_prompt,
max_gleanings=max_gleanings,
on_error=lambda e, s, d: (
callbacks.error("Claim Extraction Error", e, s, d) if callbacks else None
on_error=lambda e, s, d: logger.error(
"Claim Extraction Error", exc_info=e, extra={"stack": s, "details": d}
),
)
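Here, and in the extractors below, the callbacks.error(...) indirection is replaced with a direct logger.error(...) call: exc_info=e attaches the exception and its traceback to the record, while extra={...} adds custom attributes to the LogRecord that a Formatter or Filter can surface. A small self-contained illustration of the mechanism:

import logging

logging.basicConfig(
    # %(stack)s is resolved from the `extra` dict on each record.
    format="%(levelname)s %(message)s | stack=%(stack)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

try:
    raise ValueError("boom")
except ValueError as e:
    logger.error("Claim Extraction Error", exc_info=e, extra={"stack": "..."})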

View File

@ -18,7 +18,7 @@ from graphrag.index.operations.extract_graph.typing import (
)
from graphrag.index.utils.derive_from_rows import derive_from_rows
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
@ -36,7 +36,7 @@ async def extract_graph(
num_threads: int = 4,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Extract a graph from a piece of text using a language model."""
log.debug("entity_extract strategy=%s", strategy)
logger.debug("entity_extract strategy=%s", strategy)
if entity_types is None:
entity_types = DEFAULT_ENTITY_TYPES
strategy = strategy or {}
@ -54,7 +54,6 @@ async def extract_graph(
result = await strategy_exec(
[Document(text=text, id=id)],
entity_types,
callbacks,
cache,
strategy_config,
)
@ -67,6 +66,7 @@ async def extract_graph(
callbacks,
async_type=async_mode,
num_threads=num_threads,
progress_msg="extract graph progress: ",
)
entity_dfs = []

View File

@ -27,7 +27,7 @@ DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
@dataclass
@ -119,7 +119,7 @@ class GraphExtractor:
source_doc_map[doc_index] = text
all_records[doc_index] = result
except Exception as e:
log.exception("error extracting graph")
logger.exception("error extracting graph")
self._on_error(
e,
traceback.format_exc(),

View File

@ -3,10 +3,11 @@
"""A module containing run_graph_intelligence, run_extract_graph and _create_text_splitter methods to run graph intelligence."""
import logging
import networkx as nx
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.operations.extract_graph.graph_extractor import GraphExtractor
@ -19,11 +20,12 @@ from graphrag.index.operations.extract_graph.typing import (
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import ChatModel
logger = logging.getLogger(__name__)
async def run_graph_intelligence(
docs: list[Document],
entity_types: EntityTypes,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
args: StrategyConfig,
) -> EntityExtractionResult:
@ -34,18 +36,16 @@ async def run_graph_intelligence(
name="extract_graph",
model_type=llm_config.type,
config=llm_config,
callbacks=callbacks,
cache=cache,
)
return await run_extract_graph(llm, docs, entity_types, callbacks, args)
return await run_extract_graph(llm, docs, entity_types, args)
async def run_extract_graph(
model: ChatModel,
docs: list[Document],
entity_types: EntityTypes,
callbacks: WorkflowCallbacks | None,
args: StrategyConfig,
) -> EntityExtractionResult:
"""Run the entity extraction chain."""
@ -61,8 +61,8 @@ async def run_extract_graph(
model_invoker=model,
prompt=extraction_prompt,
max_gleanings=max_gleanings,
on_error=lambda e, s, d: (
callbacks.error("Entity Extraction Error", e, s, d) if callbacks else None
on_error=lambda e, s, d: logger.error(
"Entity Extraction Error", exc_info=e, extra={"stack": s, "details": d}
),
)
text_list = [doc.text.strip() for doc in docs]

View File

@ -11,7 +11,6 @@ from typing import Any
import networkx as nx
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
ExtractedEntity = dict[str, Any]
ExtractedRelationship = dict[str, Any]
@ -40,7 +39,6 @@ EntityExtractStrategy = Callable[
[
list[Document],
EntityTypes,
WorkflowCallbacks,
PipelineCache,
StrategyConfig,
],

View File

@ -7,7 +7,6 @@ from uuid import uuid4
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.data_model.schemas import ENTITIES_FINAL_COLUMNS
from graphrag.index.operations.compute_degree import compute_degree
@ -19,7 +18,6 @@ from graphrag.index.operations.layout_graph.layout_graph import layout_graph
def finalize_entities(
entities: pd.DataFrame,
relationships: pd.DataFrame,
callbacks: WorkflowCallbacks,
embed_config: EmbedGraphConfig | None = None,
layout_enabled: bool = False,
) -> pd.DataFrame:
@ -33,7 +31,6 @@ def finalize_entities(
)
layout = layout_graph(
graph,
callbacks,
layout_enabled,
embeddings=graph_embeddings,
)

View File

@ -3,17 +3,19 @@
"""A module containing layout_graph, _run_layout and _apply_layout_to_graph methods definition."""
import logging
import networkx as nx
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
from graphrag.index.operations.layout_graph.typing import GraphLayout
logger = logging.getLogger(__name__)
def layout_graph(
graph: nx.Graph,
callbacks: WorkflowCallbacks,
enabled: bool,
embeddings: NodeEmbeddings | None,
):
@ -44,7 +46,6 @@ def layout_graph(
graph,
enabled,
embeddings if embeddings is not None else {},
callbacks,
)
layout_df = pd.DataFrame(layout)
@ -58,7 +59,6 @@ def _run_layout(
graph: nx.Graph,
enabled: bool,
embeddings: NodeEmbeddings,
callbacks: WorkflowCallbacks,
) -> GraphLayout:
if enabled:
from graphrag.index.operations.layout_graph.umap import (
@ -68,7 +68,9 @@ def _run_layout(
return run_umap(
graph,
embeddings,
lambda e, stack, d: callbacks.error("Error in Umap", e, stack, d),
lambda e, stack, d: logger.error(
"Error in Umap", exc_info=e, extra={"stack": stack, "details": d}
),
)
from graphrag.index.operations.layout_graph.zero import (
run as run_zero,
@ -76,5 +78,7 @@ def _run_layout(
return run_zero(
graph,
lambda e, stack, d: callbacks.error("Error in Zero", e, stack, d),
lambda e, stack, d: logger.error(
"Error in Zero", exc_info=e, extra={"stack": stack, "details": d}
),
)

View File

@ -20,7 +20,7 @@ from graphrag.index.typing.error_handler import ErrorHandlerFn
# for "size" or "cluster"
# We could also have a boolean to indicate to use node sizes or clusters
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def run(
@ -56,7 +56,7 @@ def run(
**additional_args,
)
except Exception as e:
log.exception("Error running UMAP")
logger.exception("Error running UMAP")
on_error(e, traceback.format_exc(), None)
# Umap may fail due to input sparseness or memory pressure.
# For now, in these cases, we'll just return a layout with all nodes at (0, 0)

View File

@ -18,7 +18,7 @@ from graphrag.index.typing.error_handler import ErrorHandlerFn
# for "size" or "cluster"
# We could also have a boolean to indicate to use node sizes or clusters
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def run(
@ -47,7 +47,7 @@ def run(
try:
return get_zero_positions(node_labels=nodes, **additional_args)
except Exception as e:
log.exception("Error running zero-position")
logger.exception("Error running zero-position")
on_error(e, traceback.format_exc(), None)
# The zero-position layout may fail due to input sparseness or memory pressure.
# For now, in these cases, we'll just return a layout with all nodes at (0, 0)

View File

@ -13,7 +13,7 @@ from graphrag.index.typing.error_handler import ErrorHandlerFn
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# these tokens are used in the prompt
INPUT_TEXT_KEY = "input_text"
@ -86,7 +86,7 @@ class CommunityReportsExtractor:
output = response.parsed_response
except Exception as e:
log.exception("error generating community report")
logger.exception("error generating community report")
self._on_error(e, traceback.format_exc(), None)
text_output = self._get_text_output(output) if output else ""

View File

@ -32,7 +32,7 @@ from graphrag.index.utils.dataframes import (
from graphrag.logger.progress import progress_iterable
from graphrag.query.llm.text_utils import num_tokens
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def build_local_context(
@ -69,7 +69,7 @@ def _prepare_reports_at_level(
"""Prepare reports at a given level."""
# Filter and prepare node details
level_node_df = node_df[node_df[schemas.COMMUNITY_LEVEL] == level]
log.info("Number of nodes at level=%s => %s", level, len(level_node_df))
logger.info("Number of nodes at level=%s => %s", level, len(level_node_df))
nodes_set = set(level_node_df[schemas.TITLE])
# Filter and prepare edge details

View File

@ -4,7 +4,6 @@
"""A module containing run, _run_extractor and _load_nodes_edges_for_claim_chain methods definition."""
import logging
import traceback
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@ -21,7 +20,7 @@ from graphrag.index.utils.rate_limiter import RateLimiter
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import ChatModel
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def run_graph_intelligence(
@ -42,7 +41,7 @@ async def run_graph_intelligence(
cache=cache,
)
return await _run_extractor(llm, community, input, level, args, callbacks)
return await _run_extractor(llm, community, input, level, args)
async def _run_extractor(
@ -51,7 +50,6 @@ async def _run_extractor(
input: str,
level: int,
args: StrategyConfig,
callbacks: WorkflowCallbacks,
) -> CommunityReport | None:
# RateLimiter
rate_limiter = RateLimiter(rate=1, per=60)
@ -59,8 +57,8 @@ async def _run_extractor(
model,
extraction_prompt=args.get("extraction_prompt", None),
max_report_length=args.get("max_report_length", None),
on_error=lambda e, stack, _data: callbacks.error(
"Community Report Extraction Error", e, stack
on_error=lambda e, stack, _data: logger.error(
"Community Report Extraction Error", exc_info=e, extra={"stack": stack}
),
)
@ -69,7 +67,7 @@ async def _run_extractor(
results = await extractor(input)
report = results.structured_output
if report is None:
log.warning("No report found for community: %s", community)
logger.warning("No report found for community: %s", community)
return None
return CommunityReport(
@ -86,7 +84,6 @@ async def _run_extractor(
],
full_content_json=report.model_dump_json(indent=4),
)
except Exception as e:
log.exception("Error processing community: %s", community)
callbacks.error("Community Report Extraction Error", e, traceback.format_exc())
except Exception:
logger.exception("Error processing community: %s", community)
return None

View File

@ -24,7 +24,7 @@ from graphrag.index.operations.summarize_communities.utils import (
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.logger.progress import progress_ticker
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def summarize_communities(
@ -64,7 +64,7 @@ async def summarize_communities(
)
level_contexts.append(level_context)
for level_context in level_contexts:
for i, level_context in enumerate(level_contexts):
async def run_generate(record):
result = await _generate_report(
@ -85,6 +85,7 @@ async def summarize_communities(
callbacks=NoopWorkflowCallbacks(),
num_threads=num_threads,
async_type=async_mode,
progress_msg=f"level {levels[i]} summarize communities progress: ",
)
reports.extend([lr for lr in local_reports if lr is not None])

View File

@ -20,7 +20,7 @@ from graphrag.index.operations.summarize_communities.text_unit_context.sort_cont
)
from graphrag.query.llm.text_utils import num_tokens
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def build_local_context(

View File

@ -9,7 +9,7 @@ import pandas as pd
import graphrag.data_model.schemas as schemas
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def prep_text_units(

View File

@ -10,7 +10,7 @@ import pandas as pd
import graphrag.data_model.schemas as schemas
from graphrag.query.llm.text_utils import num_tokens
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def get_context_string(

View File

@ -3,8 +3,9 @@
"""A module containing run_graph_intelligence, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence."""
import logging
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.operations.summarize_descriptions.description_summary_extractor import (
SummarizeExtractor,
@ -16,11 +17,12 @@ from graphrag.index.operations.summarize_descriptions.typing import (
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import ChatModel
logger = logging.getLogger(__name__)
async def run_graph_intelligence(
id: str | tuple[str, str],
descriptions: list[str],
callbacks: WorkflowCallbacks,
cache: PipelineCache,
args: StrategyConfig,
) -> SummarizedDescriptionResult:
@ -30,18 +32,16 @@ async def run_graph_intelligence(
name="summarize_descriptions",
model_type=llm_config.type,
config=llm_config,
callbacks=callbacks,
cache=cache,
)
return await run_summarize_descriptions(llm, id, descriptions, callbacks, args)
return await run_summarize_descriptions(llm, id, descriptions, args)
async def run_summarize_descriptions(
model: ChatModel,
id: str | tuple[str, str],
descriptions: list[str],
callbacks: WorkflowCallbacks,
args: StrategyConfig,
) -> SummarizedDescriptionResult:
"""Run the entity extraction chain."""
@ -52,10 +52,10 @@ async def run_summarize_descriptions(
extractor = SummarizeExtractor(
model_invoker=model,
summarization_prompt=summarize_prompt,
on_error=lambda e, stack, details: (
callbacks.error("Entity Extraction Error", e, stack, details)
if callbacks
else None
on_error=lambda e, stack, details: logger.error(
"Entity Extraction Error",
exc_info=e,
extra={"stack": stack, "details": details},
),
max_summary_length=max_summary_length,
max_input_tokens=max_input_tokens,

View File

@ -17,7 +17,7 @@ from graphrag.index.operations.summarize_descriptions.typing import (
)
from graphrag.logger.progress import ProgressTicker, progress_ticker
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def summarize_descriptions(
@ -29,7 +29,7 @@ async def summarize_descriptions(
num_threads: int = 4,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Summarize entity and relationship descriptions from an entity graph, using a language model."""
log.debug("summarize_descriptions strategy=%s", strategy)
logger.debug("summarize_descriptions strategy=%s", strategy)
strategy = strategy or {}
strategy_exec = load_strategy(
strategy.get("type", SummarizeStrategyType.graph_intelligence)
@ -41,7 +41,11 @@ async def summarize_descriptions(
):
ticker_length = len(nodes) + len(edges)
ticker = progress_ticker(callbacks.progress, ticker_length)
ticker = progress_ticker(
callbacks.progress,
ticker_length,
description="Summarize entity/relationship description progress: ",
)
node_futures = [
do_summarize_descriptions(
@ -95,9 +99,7 @@ async def summarize_descriptions(
semaphore: asyncio.Semaphore,
):
async with semaphore:
results = await strategy_exec(
id, descriptions, callbacks, cache, strategy_config
)
results = await strategy_exec(id, descriptions, cache, strategy_config)
ticker(1)
return results

View File

@ -9,7 +9,6 @@ from enum import Enum
from typing import Any, NamedTuple
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
StrategyConfig = dict[str, Any]
@ -26,7 +25,6 @@ SummarizationStrategy = Callable[
[
str | tuple[str, str],
list[str],
WorkflowCallbacks,
PipelineCache,
StrategyConfig,
],

View File

@ -7,7 +7,6 @@ import json
import logging
import re
import time
import traceback
from collections.abc import AsyncIterable
from dataclasses import asdict
@ -17,20 +16,17 @@ from graphrag.index.run.utils import create_run_context
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.pipeline import Pipeline
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.logger.base import ProgressLogger
from graphrag.logger.progress import Progress
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.api import create_cache_from_config, create_storage_from_config
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def run_pipeline(
pipeline: Pipeline,
config: GraphRagConfig,
callbacks: WorkflowCallbacks,
logger: ProgressLogger,
is_update_run: bool = False,
) -> AsyncIterable[PipelineRunResult]:
"""Run all workflows using a simplified pipeline."""
@ -66,7 +62,6 @@ async def run_pipeline(
cache=cache,
callbacks=callbacks,
state=state,
progress_logger=logger,
)
else:
@ -78,13 +73,11 @@ async def run_pipeline(
cache=cache,
callbacks=callbacks,
state=state,
progress_logger=logger,
)
async for table in _run_pipeline(
pipeline=pipeline,
config=config,
logger=logger,
context=context,
):
yield table
@ -93,7 +86,6 @@ async def run_pipeline(
async def _run_pipeline(
pipeline: Pipeline,
config: GraphRagConfig,
logger: ProgressLogger,
context: PipelineRunContext,
) -> AsyncIterable[PipelineRunResult]:
start_time = time.time()
@ -103,13 +95,12 @@ async def _run_pipeline(
try:
await _dump_json(context)
logger.info("Executing pipeline...")
for name, workflow_function in pipeline.run():
last_workflow = name
progress = logger.child(name, transient=False)
context.callbacks.workflow_start(name, None)
work_time = time.time()
result = await workflow_function(config, context)
progress(Progress(percent=1))
context.callbacks.workflow_end(name, result)
yield PipelineRunResult(
workflow=name, result=result.result, state=context.state, errors=None
@ -120,11 +111,11 @@ async def _run_pipeline(
break
context.stats.total_runtime = time.time() - start_time
logger.info("Indexing pipeline complete.")
await _dump_json(context)
except Exception as e:
log.exception("error running workflow %s", last_workflow)
context.callbacks.error("Error running pipeline!", e, traceback.format_exc())
logger.exception("error running workflow %s", last_workflow)
yield PipelineRunResult(
workflow=last_workflow, result=None, state=context.state, errors=[e]
)
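run_pipeline no longer threads a ProgressLogger through the call stack; workflow progress and failures now arrive as ordinary log records, so a caller configures stdlib logging once up front. A caller-side sketch (pipeline, config, and callbacks are assumed to already exist; they are not constructed here):

import asyncio
import logging

logging.basicConfig(level=logging.INFO)

async def main() -> None:
    async for result in run_pipeline(pipeline, config, callbacks):
        # Each PipelineRunResult carries the workflow name and any errors.
        print(result.workflow, result.errors)

asyncio.run(main())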

View File

@ -6,15 +6,12 @@
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.callbacks.workflow_callbacks_manager import WorkflowCallbacksManager
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.state import PipelineState
from graphrag.index.typing.stats import PipelineRunStats
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.api import create_storage_from_config
@ -26,7 +23,6 @@ def create_run_context(
previous_storage: PipelineStorage | None = None,
cache: PipelineCache | None = None,
callbacks: WorkflowCallbacks | None = None,
progress_logger: ProgressLogger | None = None,
stats: PipelineRunStats | None = None,
state: PipelineState | None = None,
) -> PipelineRunContext:
@ -37,21 +33,18 @@ def create_run_context(
previous_storage=previous_storage or MemoryPipelineStorage(),
cache=cache or InMemoryCache(),
callbacks=callbacks or NoopWorkflowCallbacks(),
progress_logger=progress_logger or NullProgressLogger(),
stats=stats or PipelineRunStats(),
state=state or {},
)
def create_callback_chain(
callbacks: list[WorkflowCallbacks] | None, progress: ProgressLogger | None
callbacks: list[WorkflowCallbacks] | None,
) -> WorkflowCallbacks:
"""Create a callback manager that encompasses multiple callbacks."""
manager = WorkflowCallbacksManager()
for callback in callbacks or []:
manager.register(callback)
if progress is not None:
manager.register(ProgressWorkflowCallbacks(progress))
return manager

View File

@ -21,7 +21,7 @@ DecodeFn = Callable[[EncodedText], str]
EncodeFn = Callable[[str], EncodedText]
LengthFn = Callable[[str], int]
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
@ -100,7 +100,9 @@ class TokenTextSplitter(TextSplitter):
try:
enc = tiktoken.encoding_for_model(model_name)
except KeyError:
log.exception("Model %s not found, using %s", model_name, encoding_name)
logger.exception(
"Model %s not found, using %s", model_name, encoding_name
)
enc = tiktoken.get_encoding(encoding_name)
else:
enc = tiktoken.get_encoding(encoding_name)
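The tiktoken fallback above, condensed: encoding_for_model raises KeyError for model names tiktoken does not recognize, and logger.exception records the failure with its traceback before falling back to a named encoding (the wrapper function here is illustrative, not from this PR):

import logging

import tiktoken

logger = logging.getLogger(__name__)

def resolve_encoding(model_name: str, encoding_name: str = "cl100k_base"):
    try:
        return tiktoken.encoding_for_model(model_name)
    except KeyError:
        # logger.exception is logger.error with exc_info=True.
        logger.exception("Model %s not found, using %s", model_name, encoding_name)
        return tiktoken.get_encoding(encoding_name)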

View File

@ -10,7 +10,6 @@ from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.typing.state import PipelineState
from graphrag.index.typing.stats import PipelineRunStats
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
@ -29,7 +28,5 @@ class PipelineRunContext:
"Cache instance for reading previous LLM responses."
callbacks: WorkflowCallbacks
"Callbacks to be called during the pipeline run."
progress_logger: ProgressLogger
"Progress logger for the pipeline run."
state: PipelineState
"Arbitrary property bag for runtime state, persistent pre-computes, or experimental features."

View File

@ -37,17 +37,18 @@ async def derive_from_rows(
callbacks: WorkflowCallbacks | None = None,
num_threads: int = 4,
async_type: AsyncType = AsyncType.AsyncIO,
progress_msg: str = "",
) -> list[ItemType | None]:
"""Apply a generic transform function to each row. Any errors will be reported and thrown."""
callbacks = callbacks or NoopWorkflowCallbacks()
match async_type:
case AsyncType.AsyncIO:
return await derive_from_rows_asyncio(
input, transform, callbacks, num_threads
input, transform, callbacks, num_threads, progress_msg
)
case AsyncType.Threaded:
return await derive_from_rows_asyncio_threads(
input, transform, callbacks, num_threads
input, transform, callbacks, num_threads, progress_msg
)
case _:
msg = f"Unsupported scheduling type {async_type}"
@ -62,6 +63,7 @@ async def derive_from_rows_asyncio_threads(
transform: Callable[[pd.Series], Awaitable[ItemType]],
callbacks: WorkflowCallbacks,
num_threads: int | None = 4,
progress_msg: str = "",
) -> list[ItemType | None]:
"""
Derive from rows asynchronously.
@ -81,7 +83,9 @@ async def derive_from_rows_asyncio_threads(
return await asyncio.gather(*[execute_task(task) for task in tasks])
return await _derive_from_rows_base(input, transform, callbacks, gather)
return await _derive_from_rows_base(
input, transform, callbacks, gather, progress_msg
)
"""A module containing the derive_from_rows_async method."""
@ -92,6 +96,7 @@ async def derive_from_rows_asyncio(
transform: Callable[[pd.Series], Awaitable[ItemType]],
callbacks: WorkflowCallbacks,
num_threads: int = 4,
progress_msg: str = "",
) -> list[ItemType | None]:
"""
Derive from rows asynchronously.
@ -112,7 +117,9 @@ async def derive_from_rows_asyncio(
]
return await asyncio.gather(*tasks)
return await _derive_from_rows_base(input, transform, callbacks, gather)
return await _derive_from_rows_base(
input, transform, callbacks, gather, progress_msg
)
ItemType = TypeVar("ItemType")
@ -126,13 +133,16 @@ async def _derive_from_rows_base(
transform: Callable[[pd.Series], Awaitable[ItemType]],
callbacks: WorkflowCallbacks,
gather: GatherFn[ItemType],
progress_msg: str = "",
) -> list[ItemType | None]:
"""
Derive from rows asynchronously.
This is useful for IO bound operations.
"""
tick = progress_ticker(callbacks.progress, num_total=len(input))
tick = progress_ticker(
callbacks.progress, num_total=len(input), description=progress_msg
)
errors: list[tuple[BaseException, str]] = []
async def execute(row: tuple[Any, pd.Series]) -> ItemType | None:
@ -153,7 +163,9 @@ async def _derive_from_rows_base(
tick.done()
for error, stack in errors:
callbacks.error("parallel transformation error", error, stack)
logger.error(
"parallel transformation error", exc_info=error, extra={"stack": stack}
)
if len(errors) > 0:
raise ParallelizationError(len(errors), errors[0][1])
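With progress_msg plumbed through both schedulers down to the shared ticker, call sites can label their progress stream, as the extract-graph and extract-covariates hunks above do. A hypothetical call site (df, transform_row, and callbacks are placeholders; the call must run inside a coroutine):

from graphrag.index.utils.derive_from_rows import derive_from_rows

results = await derive_from_rows(
    df,
    transform_row,
    callbacks,
    num_threads=4,
    progress_msg="transform rows progress: ",
)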

View File

@ -11,7 +11,7 @@ import graphrag.config.defaults as defs
DEFAULT_ENCODING_NAME = defs.ENCODING_MODEL
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def num_tokens_from_string(
@ -23,7 +23,7 @@ def num_tokens_from_string(
encoding = tiktoken.encoding_for_model(model)
except KeyError:
msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
log.warning(msg)
logger.warning(msg)
encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
else:
encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)

View File

@ -4,15 +4,17 @@
"""A module containing validate_config_names definition."""
import asyncio
import logging
import sys
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager
from graphrag.logger.print_progress import ProgressLogger
logger = logging.getLogger(__name__)
def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) -> None:
def validate_config_names(parameters: GraphRagConfig) -> None:
"""Validate config file for LLM deployment name typos."""
# Validate Chat LLM configs
# TODO: Replace default_chat_model with a way to select the model
@ -28,7 +30,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
try:
asyncio.run(llm.achat("This is an LLM connectivity test. Say Hello World"))
logger.success("LLM Config Params Validated")
logger.info("LLM Config Params Validated")
except Exception as e: # noqa: BLE001
logger.error(f"LLM configuration error detected. Exiting...\n{e}") # noqa
sys.exit(1)
@ -48,7 +50,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
try:
asyncio.run(embed_llm.aembed_batch(["This is an LLM Embedding Test String"]))
logger.success("Embedding LLM Config Params Validated")
logger.info("Embedding LLM Config Params Validated")
except Exception as e: # noqa: BLE001
logger.error(f"Embedding LLM configuration error detected. Exiting...\n{e}") # noqa
sys.exit(1)
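Because stdlib logging defines no SUCCESS severity, the old logger.success(...) calls map onto logger.info(...). Not part of this PR, but for reference, a custom level could be registered if a distinct severity were ever wanted:

import logging

SUCCESS = 25  # sits between INFO (20) and WARNING (30)
logging.addLevelName(SUCCESS, "SUCCESS")

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
logging.getLogger(__name__).log(SUCCESS, "LLM Config Params Validated")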

View File

@ -4,6 +4,7 @@
"""A module containing run_workflow method definition."""
import json
import logging
from typing import Any, cast
import pandas as pd
@ -19,12 +20,15 @@ from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.progress import Progress
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform base text_units."""
logger.info("Workflow started: create_base_text_units")
documents = await load_table_from_storage("documents", context.output_storage)
chunks = config.chunks
@ -43,6 +47,7 @@ async def run_workflow(
await write_table_to_storage(output, "text_units", context.output_storage)
logger.info("Workflow completed: create_base_text_units")
return WorkflowFunctionOutput(result=output)
@ -81,7 +86,7 @@ def create_base_text_units(
)
aggregated.rename(columns={"text_with_ids": "texts"}, inplace=True)
def chunker(row: dict[str, Any]) -> Any:
def chunker(row: pd.Series) -> Any:
line_delimiter = ".\n"
metadata_str = ""
metadata_tokens = 0
@ -125,7 +130,19 @@ def create_base_text_units(
row["chunks"] = chunked
return row
aggregated = aggregated.apply(lambda row: chunker(row), axis=1)
# Track progress of row-wise apply operation
total_rows = len(aggregated)
logger.info("Starting chunking process for %d documents", total_rows)
def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
"""Add logging to chunker execution."""
result = chunker(row)
logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
return result
aggregated = aggregated.apply(
lambda row: chunker_with_logging(row, row.name), axis=1
)
aggregated = cast("pd.DataFrame", aggregated[[*group_by_columns, "chunks"]])
aggregated = aggregated.explode("chunks")
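The chunker_with_logging wrapper above reports progress from inside a row-wise apply. The same pattern in isolation, on toy data (all names here are illustrative):

import logging

import pandas as pd

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

df = pd.DataFrame({"texts": ["a", "b", "c"]})
total_rows = len(df)

def work(row: pd.Series) -> pd.Series:
    return row  # stand-in for the real per-row transform

def work_with_logging(row: pd.Series) -> pd.Series:
    result = work(row)
    # row.name is the index label; with the default RangeIndex it is the
    # zero-based position.
    logger.info("progress: %d/%d", row.name + 1, total_rows)
    return result

df = df.apply(work_with_logging, axis=1)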

View File

@ -3,6 +3,7 @@
"""A module containing run_workflow method definition."""
import logging
from datetime import datetime, timezone
from typing import cast
from uuid import uuid4
@ -18,12 +19,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform final communities."""
logger.info("Workflow started: create_communities")
entities = await load_table_from_storage("entities", context.output_storage)
relationships = await load_table_from_storage(
"relationships", context.output_storage
@ -43,6 +47,7 @@ async def run_workflow(
await write_table_to_storage(output, "communities", context.output_storage)
logger.info("Workflow completed: create_communities")
return WorkflowFunctionOutput(result=output)

View File

@ -3,6 +3,8 @@
"""A module containing run_workflow method definition."""
import logging
import pandas as pd
import graphrag.data_model.schemas as schemas
@ -32,12 +34,15 @@ from graphrag.utils.storage import (
write_table_to_storage,
)
logger = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
logger.info("Workflow started: create_community_reports")
edges = await load_table_from_storage("relationships", context.output_storage)
entities = await load_table_from_storage("entities", context.output_storage)
communities = await load_table_from_storage("communities", context.output_storage)
@ -70,6 +75,7 @@ async def run_workflow(
await write_table_to_storage(output, "community_reports", context.output_storage)
logger.info("Workflow completed: create_community_reports")
return WorkflowFunctionOutput(result=output)

View File

@ -29,7 +29,7 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def run_workflow(
@ -37,6 +37,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
logger.info("Workflow started: create_community_reports_text")
entities = await load_table_from_storage("entities", context.output_storage)
communities = await load_table_from_storage("communities", context.output_storage)
@ -64,6 +65,7 @@ async def run_workflow(
await write_table_to_storage(output, "community_reports", context.output_storage)
logger.info("Workflow completed: create_community_reports_text")
return WorkflowFunctionOutput(result=output)

View File

@ -3,6 +3,8 @@
"""A module containing run_workflow method definition."""
import logging
import pandas as pd
from graphrag.config.models.graph_rag_config import GraphRagConfig
@ -11,12 +13,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
async def run_workflow(
_config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform final documents."""
logger.info("Workflow started: create_final_documents")
documents = await load_table_from_storage("documents", context.output_storage)
text_units = await load_table_from_storage("text_units", context.output_storage)
@ -24,6 +29,7 @@ async def run_workflow(
await write_table_to_storage(output, "documents", context.output_storage)
logger.info("Workflow completed: create_final_documents")
return WorkflowFunctionOutput(result=output)

View File

@ -3,6 +3,8 @@
"""A module containing run_workflow method definition."""
import logging
import pandas as pd
from graphrag.config.models.graph_rag_config import GraphRagConfig
@ -15,12 +17,15 @@ from graphrag.utils.storage import (
write_table_to_storage,
)
logger = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform the text units."""
logger.info("Workflow started: create_final_text_units")
text_units = await load_table_from_storage("text_units", context.output_storage)
final_entities = await load_table_from_storage("entities", context.output_storage)
final_relationships = await load_table_from_storage(
@ -43,6 +48,7 @@ async def run_workflow(
await write_table_to_storage(output, "text_units", context.output_storage)
logger.info("Workflow completed: create_final_text_units")
return WorkflowFunctionOutput(result=output)

View File

@ -3,6 +3,7 @@
"""A module containing run_workflow method definition."""
import logging
from typing import Any
from uuid import uuid4
@ -20,12 +21,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to extract and format covariates."""
logger.info("Workflow started: extract_covariates")
output = None
if config.extract_claims.enabled:
text_units = await load_table_from_storage("text_units", context.output_storage)
@ -53,6 +57,7 @@ async def run_workflow(
await write_table_to_storage(output, "covariates", context.output_storage)
logger.info("Workflow completed: extract_covariates")
return WorkflowFunctionOutput(result=output)

View File

@ -3,6 +3,7 @@
"""A module containing run_workflow method definition."""
import logging
from typing import Any
import pandas as pd
@ -21,12 +22,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
logger.info("Workflow started: extract_graph")
text_units = await load_table_from_storage("text_units", context.output_storage)
extract_graph_llm_settings = config.get_language_model_config(
@ -66,6 +70,7 @@ async def run_workflow(
raw_relationships, "raw_relationships", context.output_storage
)
logger.info("Workflow completed: extract_graph")
return WorkflowFunctionOutput(
result={
"entities": entities,
@ -101,14 +106,14 @@ async def extract_graph(
if not _validate_data(extracted_entities):
error_msg = "Entity Extraction failed. No entities detected during extraction."
callbacks.error(error_msg)
logger.error(error_msg)
raise ValueError(error_msg)
if not _validate_data(extracted_relationships):
error_msg = (
"Entity Extraction failed. No relationships detected during extraction."
)
callbacks.error(error_msg)
logger.error(error_msg)
raise ValueError(error_msg)
# copy these as is before any summarization

View File

@ -3,6 +3,8 @@
"""A module containing run_workflow method definition."""
import logging
import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
@ -16,12 +18,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
logger.info("Workflow started: extract_graph_nlp")
text_units = await load_table_from_storage("text_units", context.output_storage)
entities, relationships = await extract_graph_nlp(
@ -33,6 +38,8 @@ async def run_workflow(
await write_table_to_storage(entities, "entities", context.output_storage)
await write_table_to_storage(relationships, "relationships", context.output_storage)
logger.info("Workflow completed: extract_graph_nlp")
return WorkflowFunctionOutput(
result={
"entities": entities,

View File

@ -3,6 +3,7 @@
"""Encapsulates pipeline construction and selection."""
import logging
from typing import ClassVar
from graphrag.config.enums import IndexingMethod
@ -10,6 +11,8 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.typing.pipeline import Pipeline
from graphrag.index.typing.workflow import WorkflowFunction
logger = logging.getLogger(__name__)
class PipelineFactory:
"""A factory class for workflow pipelines."""
@ -41,6 +44,7 @@ class PipelineFactory:
) -> Pipeline:
"""Create a pipeline generator."""
workflows = config.workflows or cls.pipelines.get(method, [])
logger.info("Creating pipeline with workflows: %s", workflows)
return Pipeline([(name, cls.workflows[name]) for name in workflows])

View File

@ -3,9 +3,10 @@
"""A module containing run_workflow method definition."""
import logging
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.create_graph import create_graph
@ -16,12 +17,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
logger.info("Workflow started: finalize_graph")
entities = await load_table_from_storage("entities", context.output_storage)
relationships = await load_table_from_storage(
"relationships", context.output_storage
@ -30,7 +34,6 @@ async def run_workflow(
final_entities, final_relationships = finalize_graph(
entities,
relationships,
callbacks=context.callbacks,
embed_config=config.embed_graph,
layout_enabled=config.umap.enabled,
)
@ -50,6 +53,7 @@ async def run_workflow(
storage=context.output_storage,
)
logger.info("Workflow completed: finalize_graph")
return WorkflowFunctionOutput(
result={
"entities": entities,
@ -61,13 +65,12 @@ async def run_workflow(
def finalize_graph(
entities: pd.DataFrame,
relationships: pd.DataFrame,
callbacks: WorkflowCallbacks,
embed_config: EmbedGraphConfig | None = None,
layout_enabled: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""All the steps to finalize the entity and relationship formats."""
final_entities = finalize_entities(
entities, relationships, callbacks, embed_config, layout_enabled
entities, relationships, embed_config, layout_enabled
)
final_relationships = finalize_relationships(relationships)
return (final_entities, final_relationships)

View File

@ -30,7 +30,7 @@ from graphrag.utils.storage import (
write_table_to_storage,
)
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def run_workflow(
@ -38,6 +38,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
logger.info("Workflow started: generate_text_embeddings")
documents = None
relationships = None
text_units = None
@ -81,6 +82,7 @@ async def run_workflow(
context.output_storage,
)
logger.info("Workflow completed: generate_text_embeddings")
return WorkflowFunctionOutput(result=output)
@ -145,12 +147,12 @@ async def generate_text_embeddings(
},
}
log.info("Creating embeddings")
logger.info("Creating embeddings")
outputs = {}
for field in embedded_fields:
if embedding_param_map[field]["data"] is None:
msg = f"Embedding {field} is specified but data table is not in storage. This may or may not be intentional - if you expect it to me here, please check for errors earlier in the logs."
log.warning(msg)
logger.warning(msg)
else:
outputs[field] = await _run_embeddings(
name=field,

View File

@ -12,11 +12,10 @@ from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.factory import create_input
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.storage import write_table_to_storage
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def run_workflow(
@ -27,10 +26,9 @@ async def run_workflow(
output = await load_input_documents(
config.input,
context.input_storage,
context.progress_logger,
)
log.info("Final # of rows loaded: %s", len(output))
logger.info("Final # of rows loaded: %s", len(output))
context.stats.num_documents = len(output)
await write_table_to_storage(output, "documents", context.output_storage)
@ -39,7 +37,7 @@ async def run_workflow(
async def load_input_documents(
config: InputConfig, storage: PipelineStorage, progress_logger: ProgressLogger
config: InputConfig, storage: PipelineStorage
) -> pd.DataFrame:
"""Load and parse input documents into a standard format."""
return await create_input(config, storage, progress_logger)
return await create_input(config, storage)

View File

@ -13,11 +13,10 @@ from graphrag.index.input.factory import create_input
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.update.incremental_index import get_delta_docs
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.storage import write_table_to_storage
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
async def run_workflow(
@ -29,15 +28,13 @@ async def run_workflow(
config.input,
context.input_storage,
context.previous_storage,
context.progress_logger,
)
log.info("Final # of update rows loaded: %s", len(output))
logger.info("Final # of update rows loaded: %s", len(output))
context.stats.update_documents = len(output)
if len(output) == 0:
log.warning("No new update documents found.")
context.progress_logger.warning("No new update documents found.")
logger.warning("No new update documents found.")
return WorkflowFunctionOutput(result=None, stop=True)
await write_table_to_storage(output, "documents", context.output_storage)
@ -49,10 +46,9 @@ async def load_update_documents(
config: InputConfig,
input_storage: PipelineStorage,
previous_storage: PipelineStorage,
progress_logger: ProgressLogger,
) -> pd.DataFrame:
"""Load and parse update-only input documents into a standard format."""
input_documents = await create_input(config, input_storage, progress_logger)
input_documents = await create_input(config, input_storage)
# previous storage is the output of the previous run
# we'll use this to diff the input from the prior
delta_documents = await get_delta_docs(input_documents, previous_storage)

View File

@ -3,6 +3,8 @@
"""A module containing run_workflow method definition."""
import logging
import pandas as pd
from graphrag.config.models.graph_rag_config import GraphRagConfig
@ -14,12 +16,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
logger.info("Workflow started: prune_graph")
entities = await load_table_from_storage("entities", context.output_storage)
relationships = await load_table_from_storage(
"relationships", context.output_storage
@ -36,6 +41,7 @@ async def run_workflow(
pruned_relationships, "relationships", context.output_storage
)
logger.info("Workflow completed: prune_graph")
return WorkflowFunctionOutput(
result={
"entities": pruned_entities,

View File

@ -17,7 +17,7 @@ async def run_workflow( # noqa: RUF029
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Clean the state after the update."""
logger.info("Cleaning State")
logger.info("Workflow started: update_clean_state")
keys_to_delete = [
key_name
for key_name in context.state
@ -27,4 +27,5 @@ async def run_workflow( # noqa: RUF029
for key_name in keys_to_delete:
del context.state[key_name]
logger.info("Workflow completed: update_clean_state")
return WorkflowFunctionOutput(result=None)

View File

@ -21,7 +21,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Update the communities from a incremental index run."""
logger.info("Updating Communities")
logger.info("Workflow started: update_communities")
output_storage, previous_storage, delta_storage = get_update_storages(
config, context.state["update_timestamp"]
)
@ -32,6 +32,7 @@ async def run_workflow(
context.state["incremental_update_community_id_mapping"] = community_id_mapping
logger.info("Workflow completed: update_communities")
return WorkflowFunctionOutput(result=None)

View File

@ -23,7 +23,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Update the community reports from a incremental index run."""
logger.info("Updating Community Reports")
logger.info("Workflow started: update_community_reports")
output_storage, previous_storage, delta_storage = get_update_storages(
config, context.state["update_timestamp"]
)
@ -38,6 +38,7 @@ async def run_workflow(
merged_community_reports
)
logger.info("Workflow completed: update_community_reports")
return WorkflowFunctionOutput(result=None)

View File

@ -27,6 +27,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Update the covariates from a incremental index run."""
logger.info("Workflow started: update_covariates")
output_storage, previous_storage, delta_storage = get_update_storages(
config, context.state["update_timestamp"]
)
@ -37,6 +38,7 @@ async def run_workflow(
logger.info("Updating Covariates")
await _update_covariates(previous_storage, delta_storage, output_storage)
logger.info("Workflow completed: update_covariates")
return WorkflowFunctionOutput(result=None)

View File

@ -27,7 +27,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Update the entities and relationships from a incremental index run."""
logger.info("Updating Entities and Relationships")
logger.info("Workflow started: update_entities_relationships")
output_storage, previous_storage, delta_storage = get_update_storages(
config, context.state["update_timestamp"]
)
@ -49,6 +49,7 @@ async def run_workflow(
context.state["incremental_update_merged_relationships"] = merged_relationships_df
context.state["incremental_update_entity_id_mapping"] = entity_id_mapping
logger.info("Workflow completed: update_entities_relationships")
return WorkflowFunctionOutput(result=None)

View File

@ -19,7 +19,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Update the documents from a incremental index run."""
logger.info("Updating Documents")
logger.info("Workflow started: update_final_documents")
output_storage, previous_storage, delta_storage = get_update_storages(
config, context.state["update_timestamp"]
)
@ -30,4 +30,5 @@ async def run_workflow(
context.state["incremental_update_final_documents"] = final_documents
logger.info("Workflow completed: update_final_documents")
return WorkflowFunctionOutput(result=None)

View File

@ -21,7 +21,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Update the text embeddings from a incremental index run."""
logger.info("Updating Text Embeddings")
logger.info("Workflow started: update_text_embeddings")
output_storage, _, _ = get_update_storages(
config, context.state["update_timestamp"]
)
@ -55,4 +55,5 @@ async def run_workflow(
output_storage,
)
logger.info("Workflow completed: update_text_embeddings")
return WorkflowFunctionOutput(result=None)

View File

@ -23,7 +23,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Update the text units from a incremental index run."""
logger.info("Updating Text Units")
logger.info("Workflow started: update_text_units")
output_storage, previous_storage, delta_storage = get_update_storages(
config, context.state["update_timestamp"]
)
@ -35,6 +35,7 @@ async def run_workflow(
context.state["incremental_update_merged_text_units"] = merged_text_units
logger.info("Workflow completed: update_text_units")
return WorkflowFunctionOutput(result=None)

View File

@ -11,6 +11,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar
from typing_extensions import Self
from graphrag.language_model.factory import ModelFactory
if TYPE_CHECKING:
@ -22,11 +24,11 @@ class ModelManager:
_instance: ClassVar[ModelManager | None] = None
def __new__(cls) -> ModelManager: # noqa: PYI034: False positive
def __new__(cls) -> Self:
"""Create a new instance of LLMManager if it does not exist."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
return cls._instance # type: ignore[return-value]
def __init__(self) -> None:
# Avoid reinitialization in the singleton.
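The `Self` return type lets subclasses of the singleton type-check correctly, though a narrow ignore is still needed where the `ClassVar` is typed against the base class. A standalone sketch of the same pattern (class name illustrative):

```python
from __future__ import annotations

from typing import ClassVar

from typing_extensions import Self


class Manager:
    """Illustrative singleton mirroring the Self-typed __new__ pattern."""

    _instance: ClassVar[Manager | None] = None

    def __new__(cls) -> Self:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        # _instance is declared against the base class, so the checker
        # still needs an ignore even though the annotation is now Self.
        return cls._instance  # type: ignore[return-value]


assert Manager() is Manager()  # both calls yield the same instance
```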

View File

@ -6,6 +6,7 @@
from __future__ import annotations
import asyncio
import logging
import threading
from typing import TYPE_CHECKING, Any, TypeVar
@ -26,6 +27,8 @@ if TYPE_CHECKING:
)
from graphrag.index.typing.error_handler import ErrorHandlerFn
logger = logging.getLogger(__name__)
def _create_cache(cache: PipelineCache | None, name: str) -> FNLLMCacheProvider | None:
"""Create an FNLLM cache from a pipeline cache."""
@ -34,7 +37,7 @@ def _create_cache(cache: PipelineCache | None, name: str) -> FNLLMCacheProvider
return FNLLMCacheProvider(cache).child(name)
def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:
def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn: # noqa: ARG001
"""Create an error handler from a WorkflowCallbacks."""
def on_error(
@ -42,7 +45,11 @@ def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:
stack: str | None = None,
details: dict | None = None,
) -> None:
callbacks.error("Error Invoking LLM", error, stack, details)
logger.error(
"Error Invoking LLM",
exc_info=error,
extra={"stack": stack, "details": details},
)
return on_error
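Instead of routing through `WorkflowCallbacks`, the error handler now logs directly; `exc_info` and `extra` carry the same information onto the log record. A hedged sketch of how those fields travel (values illustrative):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("example.fnllm")

try:
    raise RuntimeError("LLM call failed")
except RuntimeError as err:
    # exc_info attaches the traceback; each key in extra becomes an
    # attribute on the LogRecord (record.stack, record.details), which
    # is what a handler like BlobWorkflowLogger reads back via hasattr().
    logger.error(
        "Error Invoking LLM",
        exc_info=err,
        extra={"stack": "<formatted stack>", "details": {"attempt": 1}},
    )
```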

View File

@ -1,69 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Base classes for logging and progress reporting."""
from abc import ABC, abstractmethod
from typing import Any
from graphrag.logger.progress import Progress
class StatusLogger(ABC):
"""Provides a way to log status updates from the pipeline."""
@abstractmethod
def error(self, message: str, details: dict[str, Any] | None = None):
"""Log an error."""
@abstractmethod
def warning(self, message: str, details: dict[str, Any] | None = None):
"""Log a warning."""
@abstractmethod
def log(self, message: str, details: dict[str, Any] | None = None):
"""Report a log."""
class ProgressLogger(ABC):
"""
Abstract base class for progress loggers.
This is used to report workflow processing progress via mechanisms like progress-bars.
"""
@abstractmethod
def __call__(self, update: Progress):
"""Update progress."""
@abstractmethod
def dispose(self):
"""Dispose of the progress logger."""
@abstractmethod
def child(self, prefix: str, transient=True) -> "ProgressLogger":
"""Create a child progress bar."""
@abstractmethod
def force_refresh(self) -> None:
"""Force a refresh."""
@abstractmethod
def stop(self) -> None:
"""Stop the progress logger."""
@abstractmethod
def error(self, message: str) -> None:
"""Log an error."""
@abstractmethod
def warning(self, message: str) -> None:
"""Log a warning."""
@abstractmethod
def info(self, message: str) -> None:
"""Log information."""
@abstractmethod
def success(self, message: str) -> None:
"""Log success."""

View File

@ -4,6 +4,7 @@
"""A logger that emits updates from the indexing engine to a blob in Azure Storage."""
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
@ -11,11 +12,9 @@ from typing import Any
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
"""A logger that writes to a blob storage account."""
class BlobWorkflowLogger(logging.Handler):
"""A logging handler that writes to a blob storage account."""
_blob_service_client: BlobServiceClient
_container_name: str
@ -28,16 +27,21 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
blob_name: str = "",
base_dir: str | None = None,
storage_account_blob_url: str | None = None,
level: int = logging.NOTSET,
):
"""Create a new instance of the BlobStorageReporter class."""
"""Create a new instance of the BlobWorkflowLogger class."""
super().__init__(level)
if container_name is None:
msg = "No container name provided for blob storage."
raise ValueError(msg)
if connection_string is None and storage_account_blob_url is None:
msg = "No storage account blob url provided for blob storage."
raise ValueError(msg)
self._connection_string = connection_string
self._storage_account_blob_url = storage_account_blob_url
if self._connection_string:
self._blob_service_client = BlobServiceClient.from_connection_string(
self._connection_string
@ -65,7 +69,37 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
self._num_blocks = 0 # refresh block counter
def emit(self, record) -> None:
"""Emit a log record to blob storage."""
try:
# Create JSON structure based on record
log_data = {
"type": self._get_log_type(record.levelno),
"data": record.getMessage(),
}
# Add additional fields if they exist
if hasattr(record, "details") and record.details: # type: ignore[reportAttributeAccessIssue]
log_data["details"] = record.details # type: ignore[reportAttributeAccessIssue]
if record.exc_info and record.exc_info[1]:
log_data["cause"] = str(record.exc_info[1])
if hasattr(record, "stack") and record.stack: # type: ignore[reportAttributeAccessIssue]
log_data["stack"] = record.stack # type: ignore[reportAttributeAccessIssue]
self._write_log(log_data)
except (OSError, ValueError):
self.handleError(record)
def _get_log_type(self, level: int) -> str:
"""Get log type string based on log level."""
if level >= logging.ERROR:
return "error"
if level >= logging.WARNING:
return "warning"
return "log"
def _write_log(self, log: dict[str, Any]):
"""Write log data to blob storage."""
# create a new file when the block count approaches 25k
if (
self._num_blocks >= self._max_block_count
@ -83,27 +117,3 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
# update the blob's block count
self._num_blocks += 1
def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
):
"""Report an error."""
self._write_log({
"type": "error",
"data": message,
"cause": str(cause),
"stack": stack,
"details": details,
})
def warning(self, message: str, details: dict | None = None):
"""Report a warning."""
self._write_log({"type": "warning", "data": message, "details": details})
def log(self, message: str, details: dict | None = None):
"""Report a generic log message."""
self._write_log({"type": "log", "data": message, "details": details})

View File

@ -1,28 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Console Log."""
from typing import Any
from graphrag.logger.base import StatusLogger
class ConsoleReporter(StatusLogger):
"""A logger that writes to a console."""
def error(self, message: str, details: dict[str, Any] | None = None):
"""Log an error."""
print(message, details) # noqa T201
def warning(self, message: str, details: dict[str, Any] | None = None):
"""Log a warning."""
_print_warning(message)
def log(self, message: str, details: dict[str, Any] | None = None):
"""Log a log."""
print(message, details) # noqa T201
def _print_warning(skk):
print(f"\033[93m {skk}\033[00m") # noqa T201

View File

@ -1,43 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Factory functions for creating loggers."""
from typing import ClassVar
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.logger.rich_progress import RichProgressLogger
from graphrag.logger.types import LoggerType
class LoggerFactory:
"""A factory class for loggers."""
logger_types: ClassVar[dict[str, type]] = {}
@classmethod
def register(cls, logger_type: str, logger: type):
"""Register a custom logger implementation."""
cls.logger_types[logger_type] = logger
@classmethod
def create_logger(
cls, logger_type: LoggerType | str, kwargs: dict | None = None
) -> ProgressLogger:
"""Create a logger based on the provided type."""
if kwargs is None:
kwargs = {}
match logger_type:
case LoggerType.RICH:
return RichProgressLogger("GraphRAG Indexer ")
case LoggerType.PRINT:
return PrintProgressLogger("GraphRAG Indexer ")
case LoggerType.NONE:
return NullProgressLogger()
case _:
if logger_type in cls.logger_types:
return cls.logger_types[logger_type](**kwargs)
# default to null logger if no other logger is found
return NullProgressLogger()
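With the factory removed, the migration path is to configure handlers once at startup and request plain module loggers everywhere else; a hedged before/after sketch:

```python
import logging

from graphrag.logger.standard_logging import init_loggers

# before (removed): LoggerFactory.create_logger(LoggerType.RICH)
# after: configure handlers once, then use standard module loggers
init_loggers()
logger = logging.getLogger(__name__)
logger.info("GraphRAG Indexer started")
```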

View File

@ -1,38 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Null Progress Reporter."""
from graphrag.logger.base import Progress, ProgressLogger
class NullProgressLogger(ProgressLogger):
"""A progress logger that does nothing."""
def __call__(self, update: Progress) -> None:
"""Update progress."""
def dispose(self) -> None:
"""Dispose of the progress logger."""
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
"""Create a child progress bar."""
return self
def force_refresh(self) -> None:
"""Force a refresh."""
def stop(self) -> None:
"""Stop the progress logger."""
def error(self, message: str) -> None:
"""Log an error."""
def warning(self, message: str) -> None:
"""Log a warning."""
def info(self, message: str) -> None:
"""Log information."""
def success(self, message: str) -> None:
"""Log success."""

View File

@ -1,50 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Print Progress Logger."""
from graphrag.logger.base import Progress, ProgressLogger
class PrintProgressLogger(ProgressLogger):
"""A progress logger that prints progress to stdout."""
prefix: str
def __init__(self, prefix: str):
"""Create a new progress logger."""
self.prefix = prefix
print(f"\n{self.prefix}", end="") # noqa T201
def __call__(self, update: Progress) -> None:
"""Update progress."""
print(".", end="") # noqa T201
def dispose(self) -> None:
"""Dispose of the progress logger."""
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
"""Create a child progress bar."""
return PrintProgressLogger(prefix)
def stop(self) -> None:
"""Stop the progress logger."""
def force_refresh(self) -> None:
"""Force a refresh."""
def error(self, message: str) -> None:
"""Log an error."""
print(f"\n{self.prefix}ERROR: {message}") # noqa T201
def warning(self, message: str) -> None:
"""Log a warning."""
print(f"\n{self.prefix}WARNING: {message}") # noqa T201
def info(self, message: str) -> None:
"""Log information."""
print(f"\n{self.prefix}INFO: {message}") # noqa T201
def success(self, message: str) -> None:
"""Log success."""
print(f"\n{self.prefix}SUCCESS: {message}") # noqa T201

View File

@ -1,13 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Progress reporting types."""
"""Progress Logging Utilities."""
import logging
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import TypeVar
T = TypeVar("T")
logger = logging.getLogger(__name__)
@dataclass
@ -24,7 +26,7 @@ class Progress:
"""Total number of items"""
completed_items: int | None = None
"""Number of items completed""" ""
"""Number of items completed"""
ProgressHandler = Callable[[Progress], None]
@ -35,11 +37,15 @@ class ProgressTicker:
"""A class that emits progress reports incrementally."""
_callback: ProgressHandler | None
_description: str
_num_total: int
_num_complete: int
def __init__(self, callback: ProgressHandler | None, num_total: int):
def __init__(
self, callback: ProgressHandler | None, num_total: int, description: str = ""
):
self._callback = callback
self._description = description
self._num_total = num_total
self._num_complete = 0
@ -47,35 +53,47 @@ class ProgressTicker:
"""Emit progress."""
self._num_complete += num_ticks
if self._callback is not None:
self._callback(
Progress(
total_items=self._num_total, completed_items=self._num_complete
)
p = Progress(
total_items=self._num_total,
completed_items=self._num_complete,
description=self._description,
)
if p.description:
logger.info(
"%s%s/%s", p.description, str(p.completed_items), str(p.total_items)
)
self._callback(p)
def done(self) -> None:
"""Mark the progress as done."""
if self._callback is not None:
self._callback(
Progress(total_items=self._num_total, completed_items=self._num_total)
Progress(
total_items=self._num_total,
completed_items=self._num_total,
description=self._description,
)
)
def progress_ticker(callback: ProgressHandler | None, num_total: int) -> ProgressTicker:
def progress_ticker(
callback: ProgressHandler | None, num_total: int, description: str = ""
) -> ProgressTicker:
"""Create a progress ticker."""
return ProgressTicker(callback, num_total)
return ProgressTicker(callback, num_total, description=description)
def progress_iterable(
iterable: Iterable[T],
progress: ProgressHandler | None,
num_total: int | None = None,
description: str = "",
) -> Iterable[T]:
"""Wrap an iterable with a progress handler. Every time an item is yielded, the progress handler will be called with the current progress."""
if num_total is None:
num_total = len(list(iterable))
tick = ProgressTicker(progress, num_total)
tick = ProgressTicker(progress, num_total, description=description)
for item in iterable:
tick(1)
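The ticker now carries an optional description and mirrors each update to the module logger at INFO, replacing the old `ProgressLogger` plumbing. A short usage sketch:

```python
from graphrag.logger.progress import Progress, progress_ticker


def on_progress(p: Progress) -> None:
    # receives total/completed counts plus the new description field
    print(f"{p.description}{p.completed_items}/{p.total_items}")


# With a description set, each tick is also logged by the ticker itself.
ticker = progress_ticker(on_progress, 3, description="chunking ")
for _ in range(3):
    ticker(1)
ticker.done()
```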

View File

@ -1,165 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Rich-based progress logger for CLI use."""
# Print iterations progress
import asyncio
from rich.console import Console, Group
from rich.live import Live
from rich.progress import Progress, TaskID, TimeElapsedColumn
from rich.spinner import Spinner
from rich.tree import Tree
from graphrag.logger.base import ProgressLogger
from graphrag.logger.progress import Progress as GRProgress
# https://stackoverflow.com/a/34325723
class RichProgressLogger(ProgressLogger):
"""A rich-based progress logger for CLI use."""
_console: Console
_group: Group
_tree: Tree
_live: Live
_task: TaskID | None = None
_prefix: str
_transient: bool
_disposing: bool = False
_progressbar: Progress
_last_refresh: float = 0
def dispose(self) -> None:
"""Dispose of the progress logger."""
self._disposing = True
self._live.stop()
@property
def console(self) -> Console:
"""Get the console."""
return self._console
@property
def group(self) -> Group:
"""Get the group."""
return self._group
@property
def tree(self) -> Tree:
"""Get the tree."""
return self._tree
@property
def live(self) -> Live:
"""Get the live."""
return self._live
def __init__(
self,
prefix: str,
parent: "RichProgressLogger | None" = None,
transient: bool = True,
) -> None:
"""Create a new rich-based progress logger."""
self._prefix = prefix
if parent is None:
console = Console()
group = Group(Spinner("dots", prefix), fit=True)
tree = Tree(group)
live = Live(
tree, console=console, refresh_per_second=1, vertical_overflow="crop"
)
live.start()
self._console = console
self._group = group
self._tree = tree
self._live = live
self._transient = False
else:
self._console = parent.console
self._group = parent.group
progress_columns = [*Progress.get_default_columns(), TimeElapsedColumn()]
self._progressbar = Progress(
*progress_columns, console=self._console, transient=transient
)
tree = Tree(prefix)
tree.add(self._progressbar)
tree.hide_root = True
if parent is not None:
parent_tree = parent.tree
parent_tree.hide_root = False
parent_tree.add(tree)
self._tree = tree
self._live = parent.live
self._transient = transient
self.refresh()
def refresh(self) -> None:
"""Perform a debounced refresh."""
now = asyncio.get_event_loop().time()
duration = now - self._last_refresh
if duration > 0.1:
self._last_refresh = now
self.force_refresh()
def force_refresh(self) -> None:
"""Force a refresh."""
self.live.refresh()
def stop(self) -> None:
"""Stop the progress logger."""
self._live.stop()
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
"""Create a child progress bar."""
return RichProgressLogger(parent=self, prefix=prefix, transient=transient)
def error(self, message: str) -> None:
"""Log an error."""
self._console.print(f"❌ [red]{message}[/red]")
def warning(self, message: str) -> None:
"""Log a warning."""
self._console.print(f"⚠️ [yellow]{message}[/yellow]")
def success(self, message: str) -> None:
"""Log success."""
self._console.print(f"🚀 [green]{message}[/green]")
def info(self, message: str) -> None:
"""Log information."""
self._console.print(message)
def __call__(self, progress_update: GRProgress) -> None:
"""Update progress."""
if self._disposing:
return
progressbar = self._progressbar
if self._task is None:
self._task = progressbar.add_task(self._prefix)
progress_description = ""
if progress_update.description is not None:
progress_description = f" - {progress_update.description}"
completed = progress_update.completed_items or progress_update.percent
total = progress_update.total_items or 1
progressbar.update(
self._task,
completed=completed,
total=total,
description=f"{self._prefix}{progress_description}",
)
if completed == total and self._transient:
progressbar.update(self._task, visible=False)
self.refresh()

View File

@ -0,0 +1,153 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Standard logging configuration for the graphrag package.
This module provides a standardized way to configure Python's built-in
logging system for use within the graphrag package.
Usage:
# Configuration should be done once at the start of your application:
from graphrag.logger.standard_logging import init_loggers
init_loggers(log_file="/path/to/app.log")
# Then throughout your code:
import logging
logger = logging.getLogger(__name__) # Use standard logging
# Use standard logging methods:
logger.debug("Debug message")
logger.info("Info message")
logger.warning("Warning message")
logger.error("Error message")
logger.critical("Critical error message")
Notes
-----
The logging system is hierarchical. Loggers are organized in a tree structure,
with the root logger named 'graphrag'. All loggers created with names starting
with 'graphrag.' will be children of this root logger. This allows for consistent
configuration of all graphrag-related logs throughout the application.
All progress logging now uses this standard logging system for consistency.
"""
import logging
import sys
from pathlib import Path
from graphrag.config.enums import ReportingType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.reporting_config import ReportingConfig
LOG_FORMAT = "%(asctime)s.%(msecs)04d - %(levelname)s - %(name)s - %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
def init_loggers(
config: GraphRagConfig | None = None,
root_dir: str | None = None,
verbose: bool = False,
log_file: str | Path | None = None,
) -> None:
"""Initialize logging handlers for graphrag based on configuration.
This function merges the functionality of configure_logging() and create_pipeline_logger()
to provide a unified way to set up logging for the graphrag package.
Parameters
----------
config : GraphRagConfig | None, default=None
The GraphRAG configuration. If None, defaults to file-based reporting.
root_dir : str | None, default=None
The root directory for file-based logging.
verbose : bool, default=False
Whether to enable verbose (DEBUG) logging.
log_file : str | Path | None, default=None
Path to a specific log file. If provided, takes precedence over config.
"""
# import BlobWorkflowLogger here to avoid circular imports
from graphrag.logger.blob_workflow_logger import BlobWorkflowLogger
# extract reporting config from GraphRagConfig if provided
reporting_config: ReportingConfig
if log_file:
# if log_file is provided directly, override config to use file-based logging
log_path = Path(log_file)
reporting_config = ReportingConfig(
type=ReportingType.file,
base_dir=str(log_path.parent),
)
elif config is not None:
# use the reporting configuration from GraphRagConfig
reporting_config = config.reporting
else:
# default to file-based logging if no config provided
reporting_config = ReportingConfig(base_dir="logs", type=ReportingType.file)
logger = logging.getLogger("graphrag")
log_level = logging.DEBUG if verbose else logging.INFO
logger.setLevel(log_level)
# clear any existing handlers to avoid duplicate logs
if logger.hasHandlers():
# Close file handlers properly before removing them
for handler in logger.handlers:
if isinstance(handler, logging.FileHandler):
handler.close()
logger.handlers.clear()
# create formatter with custom format
formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT)
init_console_logger(verbose)
# add more handlers based on configuration
handler: logging.Handler
match reporting_config.type:
case ReportingType.file:
if log_file:
# use the specific log file provided
log_file_path = Path(log_file)
log_file_path.parent.mkdir(parents=True, exist_ok=True)
handler = logging.FileHandler(str(log_file_path), mode="a")
else:
# use the config-based file path
log_dir = Path(root_dir or "") / (reporting_config.base_dir or "")
log_dir.mkdir(parents=True, exist_ok=True)
log_file_path = log_dir / "logs.txt"
handler = logging.FileHandler(str(log_file_path), mode="a")
handler.setFormatter(formatter)
logger.addHandler(handler)
case ReportingType.blob:
handler = BlobWorkflowLogger(
reporting_config.connection_string,
reporting_config.container_name,
base_dir=reporting_config.base_dir,
storage_account_blob_url=reporting_config.storage_account_blob_url,
)
logger.addHandler(handler)
case _:
logger.error("Unknown reporting type '%s'.", reporting_config.type)
def init_console_logger(verbose: bool = False) -> None:
"""Initialize a console logger if not already present.
This function sets up a logger that outputs log messages to STDOUT.
Parameters
----------
verbose : bool, default=False
Whether to enable verbose (DEBUG) logging.
"""
logger = logging.getLogger("graphrag")
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
has_console_handler = any(
isinstance(h, logging.StreamHandler) for h in logger.handlers
)
if not has_console_handler:
console_handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
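A compact end-to-end example of the new entry point (logger name illustrative):

```python
import logging

from graphrag.logger.standard_logging import init_loggers

# One-time setup: console output plus a file handler at logs/app.log.
init_loggers(log_file="logs/app.log", verbose=True)

# Any logger under the "graphrag" namespace inherits those handlers.
logger = logging.getLogger("graphrag.example")
logger.debug("visible because verbose=True lowers the level to DEBUG")
```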

View File

@ -1,22 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Logging types.
This module defines the types of loggers that can be used.
"""
from enum import Enum
# Note: Code in this module was not included in the factory module because it negatively impacts the CLI experience.
class LoggerType(str, Enum):
"""The type of logger to use."""
RICH = "rich"
PRINT = "print"
NONE = "none"
def __str__(self):
"""Return a string representation of the enum value."""
return self.value

View File

@ -3,6 +3,8 @@
"""Input loading module."""
import logging
import numpy as np
import pandas as pd
@ -14,7 +16,6 @@ from graphrag.index.operations.embed_text.strategies.openai import (
run as run_embed_text,
)
from graphrag.index.workflows.create_base_text_units import create_base_text_units
from graphrag.logger.base import ProgressLogger
from graphrag.prompt_tune.defaults import (
LIMIT,
N_SUBSET_MAX,
@ -41,7 +42,7 @@ async def load_docs_in_chunks(
config: GraphRagConfig,
select_method: DocSelectionType,
limit: int,
logger: ProgressLogger,
logger: logging.Logger,
chunk_size: int,
overlap: int,
n_subset_max: int = N_SUBSET_MAX,
@ -52,7 +53,7 @@ async def load_docs_in_chunks(
config.embed_text.model_id
)
input_storage = create_storage_from_config(config.input.storage)
dataset = await create_input(config.input, input_storage, logger)
dataset = await create_input(config.input, input_storage)
chunk_config = config.chunks
chunks_df = create_base_text_units(
documents=dataset,

View File

@ -14,7 +14,7 @@ from graphrag.data_model.community_report import CommunityReport
from graphrag.data_model.entity import Entity
from graphrag.query.llm.text_utils import num_tokens
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
NO_COMMUNITY_RECORDS_WARNING: str = (
"Warning: No community records added when building community context."
@ -88,7 +88,7 @@ def build_community_context(
)
)
if compute_community_weights:
log.info("Computing community weights...")
logger.debug("Computing community weights...")
community_reports = _compute_community_weights(
community_reports=community_reports,
entities=entities,
@ -177,7 +177,7 @@ def build_community_context(
_cut_batch()
if len(all_context_records) == 0:
log.warning(NO_COMMUNITY_RECORDS_WARNING)
logger.warning(NO_COMMUNITY_RECORDS_WARNING)
return ([], {})
return all_context_text, {

View File

@ -18,7 +18,7 @@ from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.context_builder.rate_relevancy import rate_relevancy
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class DynamicCommunitySelection:
@ -109,7 +109,7 @@ class DynamicCommunitySelection:
communities_to_rate = []
for community, result in zip(queue, gather_results, strict=True):
rating = result["rating"]
log.debug(
logger.debug(
"dynamic community selection: community %s rating %s",
community,
rating,
@ -127,7 +127,7 @@ class DynamicCommunitySelection:
if child in self.reports:
communities_to_rate.append(child)
else:
log.debug(
logger.debug(
"dynamic community selection: cannot find community %s in reports",
child,
)
@ -142,7 +142,7 @@ class DynamicCommunitySelection:
and (str(level) in self.levels)
and (level <= self.max_level)
):
log.info(
logger.debug(
"dynamic community selection: no relevant community "
"reports, adding all reports at level %s to rate.",
level,
@ -155,7 +155,7 @@ class DynamicCommunitySelection:
]
end = time()
log.info(
logger.debug(
"dynamic community selection (took: %ss)\n"
"\trating distribution %s\n"
"\t%s out of %s community reports are relevant\n"

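Several chatty selection messages moved from INFO to DEBUG above; they can be re-enabled per-subsystem without flooding the rest of the pipeline. A hedged sketch, assuming these query modules sit under the graphrag.query package:

```python
import logging

# Lower the level for just the query subtree to surface the demoted
# dynamic-community-selection messages again.
logging.getLogger("graphrag.query").setLevel(logging.DEBUG)
```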
Some files were not shown because too many files have changed in this diff.