mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Improve internal logging functionality by using Python's standard logging module (#1956)
Some checks are pending
gh-pages / build (push) Waiting to run
Python CI / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python CI / python-ci (ubuntu-latest, 3.11) (push) Waiting to run
Python CI / python-ci (windows-latest, 3.10) (push) Waiting to run
Python CI / python-ci (windows-latest, 3.11) (push) Waiting to run
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Publish (pypi) / Upload release to PyPI (push) Waiting to run
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Spellcheck / spellcheck (push) Waiting to run
* Initial plan for issue
* Implement standard logging module and integrate with existing loggers
* Add test cases and improve documentation for standard logging
* Apply ruff formatting and add semversioner file for logging improvements
* Remove custom logger classes and refactor to use standard logging only
* Apply ruff formatting to resolve CI/CD test failures
* Add semversioner file and fix linting issues
* ruff fixes
* fix spelling error
* Remove StandardProgressLogger and refactor to use standard logging
* Remove LoggerFactory and custom loggers, refactor to use standard logging
* Fix pyright error: use logger.info() instead of calling logger as a function in cosmosdb_pipeline_storage.py
* ruff fixes
* Remove deprecated logger files that were marked as deprecated placeholders
* Replace custom get_logger with standard Python logging
* Fix linting issues found by ruff check --fix
* apply ruff check fixes
* add word to dictionary
* Fix type checker error in ModelManager.__new__ method
* Refactor multiple logging.getLogger() calls to use a single logger per file
* Remove progress_logger parameter from build_index() and logger parameter from generate_indexing_prompts()
* Remove logger parameter from run_pipeline and standardize logger naming
* Replace logger parameter with log_level parameter in CLI commands
* Fix import ordering in notebook files to pass poetry poe check
* Remove --logger parameter from smoke test command
* Fix Windows CI/CD issue with log file cleanup in tests
* Add StreamHandler to root logger in __main__.py for CLI logging
* Only add StreamHandler if root logger doesn't have an existing StreamHandler
* Fix import ordering in notebook files to pass ruff checks
* Replace logging.StreamHandler with colorlog.StreamHandler for colorized log output
* Regenerate poetry.lock file after adding colorlog dependency
* Fix import ordering in notebook files to pass ruff checks
* move printing of dataframes to debug level
* remove colorlog for now
* Refactor workflow callbacks to inherit from logging.Handler
* Fix linting issues in workflow callback handlers
* Fix pyright type errors in blob and file workflow callbacks
* Refactor pipeline logging to use pure logging.Handler subclasses
* Rename workflow callback classes to workflow logger classes and move to logger directory
* update dictionary
* apply ruff fixes
* fix function name
* simplify logger code
* update
* Remove error, warning, and log methods from WorkflowCallbacks and replace with standard logging
* ruff fixes
* Fix pyright errors by removing WorkflowCallbacks from strategy type signatures
* Remove ConsoleWorkflowLogger and apply consistent formatter to all handlers
* apply ruff fixes
* Refactor pipeline_logger.py to use standard FileHandler and remove FileWorkflowLogger
* Remove conditional azure import checks from blob_workflow_logger.py
* Fix pyright type checking errors in mock_provider.py and utils.py
* Run ruff check --fix to fix import ordering in notebooks
* Merge configure_logging and create_pipeline_logger into init_loggers function
* Remove configure_logging and create_pipeline_logger functions, replace all usage with init_loggers
* apply ruff fixes
* cleanup unused code
* Update init_loggers to accept GraphRagConfig instead of ReportingConfig
* apply ruff check fixes
* Fix test failures by providing valid GraphRagConfig with required model configurations
* apply ruff fixes
* remove logging_workflow_callback
* cleanup logging messages
* Add logging to track progress of pandas DataFrame apply operation in create_base_text_units
* cleanup logger logic throughout codebase
* update
* more cleanup of old loggers
* small logger cleanup
* final code cleanup and added loggers to query
* add verbose logging to query
* minor code cleanup
* Fix broken unit tests for chunk_text and standard_logging
* apply ruff fixes
* Fix test_chunk_text by mocking progress_ticker function instead of ProgressTicker class
* remove unnecessary logger
* remove rich and fix type annotation
* revert test formatting changes made by copilot
* promote graphrag logs to root logger
* add correct semversioner file
* revert change to file
* revert formatting changes that have no effect
* fix changes after merge with main
* revert unnecessary copilot changes
* remove whitespace
* cleanup docstring
* simplify some logic with less code
* update poetry lock file
* ruff fixes

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
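The thread running through the commits above is replacing custom logger classes with Python's standard logging module: one module-level logger per file, named after the module, with lazy %-style message arguments. A minimal sketch of that pattern (function and workflow names here are illustrative, not from the codebase):

```python
import logging

# One module-level logger per file, named after the module so its records
# roll up through the dotted-name hierarchy to the "graphrag" root logger.
logger = logging.getLogger(__name__)


def run_workflow(name: str) -> str:
    # Lazy %-style arguments: the message is only interpolated if a handler
    # actually consumes the record, unlike an eager f-string.
    logger.info("Workflow %s started", name)
    result = f"{name}: ok"
    logger.debug("Workflow result: %s", result)
    return result
```

Because loggers are hierarchical, configuring handlers once on the root (or `graphrag`) logger covers every module that follows this pattern.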
This commit is contained in:
parent
27c6de846f
commit
e84df28e64
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "cleaned up logging to follow python standards."
+}
@@ -102,6 +102,7 @@ itertuples
 isin
 nocache
 nbconvert
+levelno

 # HTML
 nbsp
@@ -186,6 +187,7 @@ Verdantis's
 # English
 skippable
 upvote
+unconfigured

 # Misc
 Arxiv
@@ -178,11 +178,11 @@ This section controls the cache mechanism used by the pipeline. This is used to

 ### Reporting

-This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to the console or to an Azure Blob Storage container.
+This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to an Azure Blob Storage container.

 | Parameter | Description | Type | Required or Optional | Default |
 | --------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | -------------------- | ------- |
-| `GRAPHRAG_REPORTING_TYPE` | The type of reporter to use. Options are `file`, `console`, or `blob` | `str` | optional | `file` |
+| `GRAPHRAG_REPORTING_TYPE` | The type of reporter to use. Options are `file` or `blob` | `str` | optional | `file` |
 | `GRAPHRAG_REPORTING_STORAGE_ACCOUNT_BLOB_URL` | The Azure Storage blob endpoint to use when in `blob` mode and using managed identity. Will have the format `https://<storage_account_name>.blob.core.windows.net` | `str` | optional | None |
 | `GRAPHRAG_REPORTING_CONNECTION_STRING` | The Azure Storage connection string to use when in `blob` mode. | `str` | optional | None |
 | `GRAPHRAG_REPORTING_CONTAINER_NAME` | The Azure Storage container name to use when in `blob` mode. | `str` | optional | None |
@@ -149,11 +149,11 @@ This section controls the cache mechanism used by the pipeline. This is used to

 ### reporting

-This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to the console or to an Azure Blob Storage container.
+This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to an Azure Blob Storage container.

 #### Fields

-- `type` **file|console|blob** - The reporting type to use. Default=`file`
+- `type` **file|blob** - The reporting type to use. Default=`file`
 - `base_dir` **str** - The base directory to write reports to, relative to the root.
 - `connection_string` **str** - (blob only) The Azure Storage connection string.
 - `container_name` **str** - (blob only) The Azure Storage container name.
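The `file` reporting type described above maps naturally onto a standard `logging.FileHandler` attached to the `graphrag` logger, which is the direction this PR takes. A hedged sketch of the idea — the `logs.txt` filename, format string, and helper name here are illustrative assumptions, not the library's actual configuration:

```python
import logging
from pathlib import Path


def attach_file_reporter(base_dir: str) -> logging.Handler:
    """Attach a FileHandler that captures pipeline events under base_dir.

    Illustrative sketch: the real init lives in graphrag's logging setup;
    the filename and format below are assumptions.
    """
    log_path = Path(base_dir) / "logs.txt"
    handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
    handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(name)s - %(message)s")
    )
    reporter = logging.getLogger("graphrag")
    reporter.addHandler(handler)
    return handler
```

Any module logging through a `graphrag.*` logger then lands in the file with no per-module wiring.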
@@ -2,3 +2,10 @@
 # Licensed under the MIT License

 """The GraphRAG package."""
+
+import logging
+
+from graphrag.logger.standard_logging import init_console_logger
+
+logger = logging.getLogger(__name__)
+init_console_logger()
@@ -10,7 +10,7 @@ Backwards compatibility is not guaranteed at this time.

 import logging

-from graphrag.callbacks.reporting import create_pipeline_reporter
+from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.enums import IndexingMethod
 from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -19,10 +19,9 @@ from graphrag.index.run.utils import create_callback_chain
 from graphrag.index.typing.pipeline_run_result import PipelineRunResult
 from graphrag.index.typing.workflow import WorkflowFunction
 from graphrag.index.workflows.factory import PipelineFactory
-from graphrag.logger.base import ProgressLogger
-from graphrag.logger.null_progress import NullProgressLogger
+from graphrag.logger.standard_logging import init_loggers

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 async def build_index(
@@ -31,7 +30,6 @@ async def build_index(
     is_update_run: bool = False,
     memory_profile: bool = False,
     callbacks: list[WorkflowCallbacks] | None = None,
-    progress_logger: ProgressLogger | None = None,
 ) -> list[PipelineRunResult]:
     """Run the pipeline with the given configuration.

@@ -45,26 +43,25 @@
         Whether to enable memory profiling.
     callbacks : list[WorkflowCallbacks] | None default=None
         A list of callbacks to register.
-    progress_logger : ProgressLogger | None default=None
-        The progress logger.

     Returns
     -------
     list[PipelineRunResult]
         The list of pipeline run results
     """
-    logger = progress_logger or NullProgressLogger()
-    # create a pipeline reporter and add to any additional callbacks
-    callbacks = callbacks or []
-    callbacks.append(create_pipeline_reporter(config.reporting, None))
+    init_loggers(config=config)

-    workflow_callbacks = create_callback_chain(callbacks, logger)
+    # Create callbacks for pipeline lifecycle events if provided
+    workflow_callbacks = (
+        create_callback_chain(callbacks) if callbacks else NoopWorkflowCallbacks()
+    )

     outputs: list[PipelineRunResult] = []

     if memory_profile:
-        log.warning("New pipeline does not yet support memory profiling.")
+        logger.warning("New pipeline does not yet support memory profiling.")

+    logger.info("Initializing indexing pipeline...")
     # todo: this could propagate out to the cli for better clarity, but will be a breaking api change
     method = _get_method(method, is_update_run)
     pipeline = PipelineFactory.create_pipeline(config, method)
@@ -75,15 +72,14 @@
         pipeline,
         config,
         callbacks=workflow_callbacks,
-        logger=logger,
         is_update_run=is_update_run,
     ):
         outputs.append(output)
         if output.errors and len(output.errors) > 0:
-            logger.error(output.workflow)
+            logger.error("Workflow %s completed with errors", output.workflow)
         else:
-            logger.success(output.workflow)
-            logger.info(str(output.result))
+            logger.info("Workflow %s completed successfully", output.workflow)
+            logger.debug(str(output.result))

     workflow_callbacks.pipeline_end(outputs)
     return outputs
@@ -11,6 +11,7 @@ WARNING: This API is under development and may undergo changes in future release
 Backwards compatibility is not guaranteed at this time.
 """

+import logging
 from typing import Annotated

 import annotated_types
@@ -20,7 +21,7 @@ from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.defaults import graphrag_config_defaults
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.language_model.manager import ModelManager
-from graphrag.logger.base import ProgressLogger
+from graphrag.logger.standard_logging import init_loggers
 from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT, PROMPT_TUNING_MODEL_ID
 from graphrag.prompt_tune.generator.community_report_rating import (
     generate_community_report_rating,
@@ -47,11 +48,12 @@ from graphrag.prompt_tune.generator.persona import generate_persona
 from graphrag.prompt_tune.loader.input import load_docs_in_chunks
 from graphrag.prompt_tune.types import DocSelectionType

+logger = logging.getLogger(__name__)
+

 @validate_call(config={"arbitrary_types_allowed": True})
 async def generate_indexing_prompts(
     config: GraphRagConfig,
-    logger: ProgressLogger,
     chunk_size: PositiveInt = graphrag_config_defaults.chunks.size,
     overlap: Annotated[
         int, annotated_types.Gt(-1)
@@ -71,8 +73,6 @@ async def generate_indexing_prompts(
     Parameters
     ----------
     - config: The GraphRag configuration.
     - logger: The logger to use for progress updates.
     - root: The root directory.
     - output_path: The path to store the prompts.
     - chunk_size: The chunk token size to use for input text units.
     - limit: The limit of chunks to load.
@@ -89,6 +89,8 @@
     -------
     tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
     """
+    init_loggers(config=config)
+
     # Retrieve documents
     logger.info("Chunking documents...")
     doc_list = await load_docs_in_chunks(
@@ -187,9 +189,9 @@
         language=language,
     )

-    logger.info(f"\nGenerated domain: {domain}")  # noqa: G004
-    logger.info(f"\nDetected language: {language}")  # noqa: G004
-    logger.info(f"\nGenerated persona: {persona}")  # noqa: G004
+    logger.debug("Generated domain: %s", domain)
+    logger.debug("Detected language: %s", language)
+    logger.debug("Generated persona: %s", persona)

     return (
         extract_graph_prompt,
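The swap above from f-strings (previously silenced with `noqa: G004`) to %-style arguments is not cosmetic: an f-string is rendered even when the record's level is disabled, while %-style arguments are only formatted if the record is actually emitted. A small stdlib demonstration, with a stand-in class playing the role of an expensive value such as a large DataFrame:

```python
import logging

calls = []


class Costly:
    """Stand-in for an expensive value, e.g. a large DataFrame repr."""

    def __str__(self) -> str:
        calls.append(1)
        return "rendered"


demo_logger = logging.getLogger("demo.lazy")
demo_logger.setLevel(logging.INFO)  # DEBUG is disabled for this logger

value = Costly()
demo_logger.debug("value: %s", value)  # lazy: level check short-circuits, __str__ never runs
lazy_calls = len(calls)
demo_logger.debug(f"value: {value}")   # eager: the f-string renders regardless of level
```

This is exactly what the `G004` lint rule flags, and why the PR could drop those suppressions.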
@@ -17,6 +17,7 @@ WARNING: This API is under development and may undergo changes in future release
 Backwards compatibility is not guaranteed at this time.
 """

+import logging
 from collections.abc import AsyncGenerator
 from typing import Any

@@ -31,7 +32,7 @@ from graphrag.config.embeddings import (
     text_unit_text_embedding,
 )
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.logger.print_progress import PrintProgressLogger
+from graphrag.logger.standard_logging import init_loggers
 from graphrag.query.factory import (
     get_basic_search_engine,
     get_drift_search_engine,
@@ -50,11 +51,13 @@ from graphrag.query.indexer_adapters import (
 from graphrag.utils.api import (
     get_embedding_store,
     load_search_prompt,
+    truncate,
     update_context_data,
 )
 from graphrag.utils.cli import redact

-logger = PrintProgressLogger("")
+# Initialize standard logger
+logger = logging.getLogger(__name__)


 @validate_call(config={"arbitrary_types_allowed": True})
@@ -68,6 +71,7 @@
     response_type: str,
     query: str,
     callbacks: list[QueryCallbacks] | None = None,
+    verbose: bool = False,
 ) -> tuple[
     str | dict[str, Any] | list[dict[str, Any]],
     str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@@ -88,11 +92,9 @@
     Returns
     -------
     TODO: Document the search response type and format.
-
-    Raises
-    ------
-    TODO: Document any exceptions to expect.
     """
+    init_loggers(config=config, verbose=verbose)
+
     callbacks = callbacks or []
     full_response = ""
     context_data = {}
@@ -105,6 +107,7 @@
     local_callbacks.on_context = on_context
     callbacks.append(local_callbacks)

+    logger.debug("Executing global search query: %s", query)
     async for chunk in global_search_streaming(
         config=config,
         entities=entities,
@@ -117,6 +120,7 @@
         callbacks=callbacks,
     ):
         full_response += chunk
+    logger.debug("Query response: %s", truncate(full_response, 400))
     return full_response, context_data

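The `truncate(full_response, 400)` call above caps the response text before it hits the debug log. The real helper lives in `graphrag.utils.api` and its body is not shown in this diff, so the following is only a guessed, hypothetical equivalent of what such a helper typically does:

```python
def truncate(text: str, max_length: int) -> str:
    """Hypothetical stand-in for graphrag.utils.api.truncate: cap a string
    for logging and mark the cut with an ellipsis. Not the library's code."""
    if len(text) <= max_length:
        return text
    return text[:max_length] + "..."
```

Capping logged payloads keeps debug logs readable when responses run to thousands of characters.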
@@ -131,6 +135,7 @@
     response_type: str,
     query: str,
     callbacks: list[QueryCallbacks] | None = None,
+    verbose: bool = False,
 ) -> AsyncGenerator:
     """Perform a global search and return the context data and response via a generator.

@@ -150,11 +155,9 @@
     Returns
     -------
     TODO: Document the search response type and format.
-
-    Raises
-    ------
-    TODO: Document any exceptions to expect.
     """
+    init_loggers(config=config, verbose=verbose)
+
     communities_ = read_indexer_communities(communities, community_reports)
     reports = read_indexer_reports(
         community_reports,
@@ -173,6 +176,7 @@
         config.root_dir, config.global_search.knowledge_prompt
     )

+    logger.debug("Executing streaming global search query: %s", query)
     search_engine = get_global_search_engine(
         config,
         reports=reports,
@@ -201,6 +205,7 @@
     streaming: bool,
     query: str,
     callbacks: list[QueryCallbacks] | None = None,
+    verbose: bool = False,
 ) -> tuple[
     str | dict[str, Any] | list[dict[str, Any]],
     str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@@ -223,11 +228,9 @@
     Returns
     -------
     TODO: Document the search response type and format.
-
-    Raises
-    ------
-    TODO: Document any exceptions to expect.
     """
+    init_loggers(config=config, verbose=verbose)
+
     # Streaming not supported yet
     if streaming:
         message = "Streaming not yet implemented for multi_global_search"
@@ -311,6 +314,7 @@
         communities_dfs, axis=0, ignore_index=True, sort=False
     )

+    logger.debug("Executing multi-index global search query: %s", query)
     result = await global_search(
         config,
         entities=entities_combined,
@@ -326,6 +330,7 @@
     # Update the context data by linking index names and community ids
     context = update_context_data(result[1], links)

+    logger.debug("Query response: %s", truncate(result[0], 400))  # type: ignore
     return (result[0], context)

@@ -342,6 +347,7 @@
     response_type: str,
     query: str,
     callbacks: list[QueryCallbacks] | None = None,
+    verbose: bool = False,
 ) -> tuple[
     str | dict[str, Any] | list[dict[str, Any]],
     str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@@ -362,11 +368,9 @@
     Returns
     -------
     TODO: Document the search response type and format.
-
-    Raises
-    ------
-    TODO: Document any exceptions to expect.
     """
+    init_loggers(config=config, verbose=verbose)
+
     callbacks = callbacks or []
     full_response = ""
     context_data = {}
@@ -379,6 +383,7 @@
     local_callbacks.on_context = on_context
     callbacks.append(local_callbacks)

+    logger.debug("Executing local search query: %s", query)
     async for chunk in local_search_streaming(
         config=config,
         entities=entities,
@@ -393,6 +398,7 @@
         callbacks=callbacks,
     ):
         full_response += chunk
+    logger.debug("Query response: %s", truncate(full_response, 400))
     return full_response, context_data

@@ -409,6 +415,7 @@
     response_type: str,
     query: str,
     callbacks: list[QueryCallbacks] | None = None,
+    verbose: bool = False,
 ) -> AsyncGenerator:
     """Perform a local search and return the context data and response via a generator.

@@ -427,16 +434,14 @@
     Returns
     -------
     TODO: Document the search response type and format.
-
-    Raises
-    ------
-    TODO: Document any exceptions to expect.
     """
+    init_loggers(config=config, verbose=verbose)
+
     vector_store_args = {}
     for index, store in config.vector_store.items():
         vector_store_args[index] = store.model_dump()
     msg = f"Vector Store Args: {redact(vector_store_args)}"
-    logger.info(msg)
+    logger.debug(msg)

     description_embedding_store = get_embedding_store(
         config_args=vector_store_args,
@@ -447,6 +452,7 @@
     covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
     prompt = load_search_prompt(config.root_dir, config.local_search.prompt)

+    logger.debug("Executing streaming local search query: %s", query)
     search_engine = get_local_search_engine(
         config=config,
         reports=read_indexer_reports(community_reports, communities, community_level),
@@ -477,6 +483,7 @@
     streaming: bool,
     query: str,
     callbacks: list[QueryCallbacks] | None = None,
+    verbose: bool = False,
 ) -> tuple[
     str | dict[str, Any] | list[dict[str, Any]],
     str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@@ -500,11 +507,9 @@
     Returns
     -------
     TODO: Document the search response type and format.
-
-    Raises
-    ------
-    TODO: Document any exceptions to expect.
     """
+    init_loggers(config=config, verbose=verbose)
+
     # Streaming not supported yet
     if streaming:
         message = "Streaming not yet implemented for multi_index_local_search"
@@ -670,6 +675,7 @@
     covariates_combined = pd.concat(
         covariates_dfs, axis=0, ignore_index=True, sort=False
     )
+    logger.debug("Executing multi-index local search query: %s", query)
     result = await local_search(
         config,
         entities=entities_combined,
@@ -687,6 +693,7 @@
     # Update the context data by linking index names and community ids
     context = update_context_data(result[1], links)

+    logger.debug("Query response: %s", truncate(result[0], 400))  # type: ignore
     return (result[0], context)

@@ -702,6 +709,7 @@
     response_type: str,
     query: str,
     callbacks: list[QueryCallbacks] | None = None,
+    verbose: bool = False,
 ) -> tuple[
     str | dict[str, Any] | list[dict[str, Any]],
     str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@@ -721,11 +729,9 @@
     Returns
     -------
     TODO: Document the search response type and format.
-
-    Raises
-    ------
-    TODO: Document any exceptions to expect.
     """
+    init_loggers(config=config, verbose=verbose)
+
     callbacks = callbacks or []
     full_response = ""
     context_data = {}
@@ -738,6 +744,7 @@
     local_callbacks.on_context = on_context
     callbacks.append(local_callbacks)

+    logger.debug("Executing drift search query: %s", query)
     async for chunk in drift_search_streaming(
         config=config,
         entities=entities,
@@ -751,6 +758,7 @@
         callbacks=callbacks,
     ):
         full_response += chunk
+    logger.debug("Query response: %s", truncate(full_response, 400))
     return full_response, context_data

@@ -766,6 +774,7 @@
     response_type: str,
     query: str,
     callbacks: list[QueryCallbacks] | None = None,
+    verbose: bool = False,
 ) -> AsyncGenerator:
     """Perform a DRIFT search and return the context data and response.

@@ -782,16 +791,14 @@
     Returns
     -------
     TODO: Document the search response type and format.
-
-    Raises
-    ------
-    TODO: Document any exceptions to expect.
     """
+    init_loggers(config=config, verbose=verbose)
+
     vector_store_args = {}
     for index, store in config.vector_store.items():
         vector_store_args[index] = store.model_dump()
     msg = f"Vector Store Args: {redact(vector_store_args)}"
-    logger.info(msg)
+    logger.debug(msg)

     description_embedding_store = get_embedding_store(
         config_args=vector_store_args,
@@ -811,6 +818,7 @@
         config.root_dir, config.drift_search.reduce_prompt
     )

+    logger.debug("Executing streaming drift search query: %s", query)
     search_engine = get_drift_search_engine(
         config=config,
         reports=reports,
@@ -840,6 +848,7 @@
     streaming: bool,
     query: str,
     callbacks: list[QueryCallbacks] | None = None,
+    verbose: bool = False,
 ) -> tuple[
     str | dict[str, Any] | list[dict[str, Any]],
     str | list[pd.DataFrame] | dict[str, pd.DataFrame],
@@ -862,11 +871,9 @@
     Returns
     -------
     TODO: Document the search response type and format.
-
-    Raises
-    ------
-    TODO: Document any exceptions to expect.
     """
+    init_loggers(config=config, verbose=verbose)
+
     # Streaming not supported yet
     if streaming:
         message = "Streaming not yet implemented for multi_drift_search"
@ -1009,6 +1016,7 @@ async def multi_index_drift_search(
|
||||
text_units_dfs, axis=0, ignore_index=True, sort=False
|
||||
)
|
||||
|
||||
logger.debug("Executing multi-index drift search query: %s", query)
|
||||
result = await drift_search(
|
||||
config,
|
||||
entities=entities_combined,
|
||||
@ -1029,6 +1037,8 @@ async def multi_index_drift_search(
|
||||
context[key] = update_context_data(result[1][key], links)
|
||||
else:
|
||||
context = result[1]
|
||||
|
||||
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
|
||||
return (result[0], context)
|
||||
|
||||
|
||||
@ -1038,6 +1048,7 @@ async def basic_search(
|
||||
text_units: pd.DataFrame,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
||||
@ -1053,11 +1064,9 @@ async def basic_search(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
|
||||
callbacks = callbacks or []
|
||||
full_response = ""
|
||||
context_data = {}
|
||||
@ -1070,6 +1079,7 @@ async def basic_search(
|
||||
local_callbacks.on_context = on_context
|
||||
callbacks.append(local_callbacks)
|
||||
|
||||
logger.debug("Executing basic search query: %s", query)
|
||||
async for chunk in basic_search_streaming(
|
||||
config=config,
|
||||
text_units=text_units,
|
||||
@ -1077,6 +1087,7 @@ async def basic_search(
|
||||
callbacks=callbacks,
|
||||
):
|
||||
full_response += chunk
|
||||
logger.debug("Query response: %s", truncate(full_response, 400))
|
||||
return full_response, context_data
|
||||
|
||||
|
||||
@ -1086,6 +1097,7 @@ def basic_search_streaming(
|
||||
text_units: pd.DataFrame,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> AsyncGenerator:
|
||||
"""Perform a local search and return the context data and response via a generator.
|
||||
|
||||
@ -1098,16 +1110,14 @@ def basic_search_streaming(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
|
||||
vector_store_args = {}
|
||||
for index, store in config.vector_store.items():
|
||||
vector_store_args[index] = store.model_dump()
|
||||
msg = f"Vector Store Args: {redact(vector_store_args)}"
|
||||
logger.info(msg)
|
||||
logger.debug(msg)
|
||||
|
||||
description_embedding_store = get_embedding_store(
|
||||
config_args=vector_store_args,
|
||||
@ -1116,6 +1126,7 @@ def basic_search_streaming(
|
||||
|
||||
prompt = load_search_prompt(config.root_dir, config.basic_search.prompt)
|
||||
|
||||
logger.debug("Executing streaming basic search query: %s", query)
|
||||
search_engine = get_basic_search_engine(
|
||||
config=config,
|
||||
text_units=read_indexer_text_units(text_units),
|
||||
@ -1134,6 +1145,7 @@ async def multi_index_basic_search(
|
||||
streaming: bool,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
||||
@ -1151,11 +1163,9 @@ async def multi_index_basic_search(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
message = "Streaming not yet implemented for multi_basic_search"
|
||||
@ -1192,6 +1202,7 @@ async def multi_index_basic_search(
|
||||
text_units_dfs, axis=0, ignore_index=True, sort=False
|
||||
)
|
||||
|
||||
logger.debug("Executing multi-index basic search query: %s", query)
|
||||
return await basic_search(
|
||||
config,
|
||||
text_units=text_units_combined,
|
||||
|
||||
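The hunks above repeatedly swap eager f-string log calls for lazy %-style ones such as `logger.debug("Query response: %s", truncate(full_response, 400))`. A minimal sketch of that pattern; the `truncate` helper below is an illustrative stand-in, not GraphRAG's implementation:

```python
import logging


def truncate(text: str, length: int) -> str:
    # Stand-in helper: cap the payload so huge responses don't flood the log.
    return text if len(text) <= length else text[:length] + "..."


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("graphrag.sketch")

full_response = "x" * 500
# %-style arguments defer string interpolation: the message is only
# formatted if a handler will actually emit the record. (The truncate
# call itself still runs eagerly.)
logger.debug("Query response: %s", truncate(full_response, 400))
```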
@@ -1,32 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A logger that emits updates from the indexing engine to the console."""

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks


class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
"""A logger that writes to a console."""

def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
):
"""Handle when an error occurs."""
print(message, str(cause), stack, details)  # noqa T201

def warning(self, message: str, details: dict | None = None):
"""Handle when a warning occurs."""
_print_warning(message)

def log(self, message: str, details: dict | None = None):
"""Handle when a log message is produced."""
print(message, details)  # noqa T201


def _print_warning(skk):
print("\033[93m {}\033[00m".format(skk))  # noqa T201
@@ -1,78 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A logger that emits updates from the indexing engine to a local file."""

import json
import logging
from io import TextIOWrapper
from pathlib import Path

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks

log = logging.getLogger(__name__)


class FileWorkflowCallbacks(NoopWorkflowCallbacks):
"""A logger that writes to a local file."""

_out_stream: TextIOWrapper

def __init__(self, directory: str):
"""Create a new file-based workflow logger."""
Path(directory).mkdir(parents=True, exist_ok=True)
self._out_stream = open(  # noqa: PTH123, SIM115
Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict"
)

def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
):
"""Handle when an error occurs."""
self._out_stream.write(
json.dumps(
{
"type": "error",
"data": message,
"stack": stack,
"source": str(cause),
"details": details,
},
indent=4,
ensure_ascii=False,
)
+ "\n"
)
message = f"{message} details={details}"
log.info(message)

def warning(self, message: str, details: dict | None = None):
"""Handle when a warning occurs."""
self._out_stream.write(
json.dumps(
{"type": "warning", "data": message, "details": details},
ensure_ascii=False,
)
+ "\n"
)
_print_warning(message)

def log(self, message: str, details: dict | None = None):
"""Handle when a log message is produced."""
self._out_stream.write(
json.dumps(
{"type": "log", "data": message, "details": details}, ensure_ascii=False
)
+ "\n"
)

message = f"{message} details={details}"
log.info(message)


def _print_warning(skk):
log.warning(skk)
@@ -9,13 +9,13 @@ from graphrag.logger.progress import Progress


class NoopWorkflowCallbacks(WorkflowCallbacks):
"""A no-op implementation of WorkflowCallbacks."""
"""A no-op implementation of WorkflowCallbacks that logs all events to standard logging."""

def pipeline_start(self, names: list[str]) -> None:
"""Execute this callback when a the entire pipeline starts."""
"""Execute this callback to signal when the entire pipeline starts."""

def pipeline_end(self, results: list[PipelineRunResult]) -> None:
"""Execute this callback when the entire pipeline ends."""
"""Execute this callback to signal when the entire pipeline ends."""

def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
@@ -25,18 +25,3 @@ class NoopWorkflowCallbacks(WorkflowCallbacks):

def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""

def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
) -> None:
"""Handle when an error occurs."""

def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""

def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log message occurs."""

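The no-op base above lets concrete callback classes override only the events they care about while every other hook stays safe to call. A reduced sketch of that pattern, with illustrative names:

```python
class NoopCallbacks:
    """Base class: every event hook is a safe no-op."""

    def workflow_start(self, name: str) -> None:
        """Handle when a workflow starts."""

    def workflow_end(self, name: str) -> None:
        """Handle when a workflow ends."""


class CountingCallbacks(NoopCallbacks):
    """Override only workflow_start; workflow_end stays a no-op."""

    def __init__(self) -> None:
        self.started: list[str] = []

    def workflow_start(self, name: str) -> None:
        self.started.append(name)


cb = CountingCallbacks()
cb.workflow_start("extract_graph")
cb.workflow_end("extract_graph")  # inherited no-op; nothing to implement
```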
@@ -1,42 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A workflow callback manager that emits updates."""

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.logger.base import ProgressLogger
from graphrag.logger.progress import Progress


class ProgressWorkflowCallbacks(NoopWorkflowCallbacks):
"""A callback manager that delegates to a ProgressLogger."""

_root_progress: ProgressLogger
_progress_stack: list[ProgressLogger]

def __init__(self, progress: ProgressLogger) -> None:
"""Create a new ProgressWorkflowCallbacks."""
self._progress = progress
self._progress_stack = [progress]

def _pop(self) -> None:
self._progress_stack.pop()

def _push(self, name: str) -> None:
self._progress_stack.append(self._latest.child(name))

@property
def _latest(self) -> ProgressLogger:
return self._progress_stack[-1]

def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
self._push(name)

def workflow_end(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow ends."""
self._pop()

def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
self._latest(progress)
@@ -1,39 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing the pipeline reporter factory."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
from graphrag.config.enums import ReportingType
from graphrag.config.models.reporting_config import ReportingConfig

if TYPE_CHECKING:
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks


def create_pipeline_reporter(
config: ReportingConfig | None, root_dir: str | None
) -> WorkflowCallbacks:
"""Create a logger for the given pipeline config."""
config = config or ReportingConfig(base_dir="logs", type=ReportingType.file)
match config.type:
case ReportingType.file:
return FileWorkflowCallbacks(
str(Path(root_dir or "") / (config.base_dir or ""))
)
case ReportingType.console:
return ConsoleWorkflowCallbacks()
case ReportingType.blob:
return BlobWorkflowCallbacks(
config.connection_string,
config.container_name,
base_dir=config.base_dir,
storage_account_blob_url=config.storage_account_blob_url,
)
@@ -35,21 +35,3 @@ class WorkflowCallbacks(Protocol):
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
...

def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
) -> None:
"""Handle when an error occurs."""
...

def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""
...

def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log message occurs."""
...

@@ -50,27 +50,3 @@ class WorkflowCallbacksManager(WorkflowCallbacks):
for callback in self._callbacks:
if hasattr(callback, "progress"):
callback.progress(progress)

def error(
self,
message: str,
cause: BaseException | None = None,
stack: str | None = None,
details: dict | None = None,
) -> None:
"""Handle when an error occurs."""
for callback in self._callbacks:
if hasattr(callback, "error"):
callback.error(message, cause, stack, details)

def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""
for callback in self._callbacks:
if hasattr(callback, "warning"):
callback.warning(message, details)

def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log message occurs."""
for callback in self._callbacks:
if hasattr(callback, "log"):
callback.log(message, details)

@@ -10,49 +10,26 @@ import warnings
from pathlib import Path

import graphrag.api as api
from graphrag.config.enums import CacheType, IndexingMethod
from graphrag.config.enums import CacheType, IndexingMethod, ReportingType
from graphrag.config.load_config import load_config
from graphrag.config.logging import enable_logging_with_config
from graphrag.index.validate_config import validate_config_names
from graphrag.logger.base import ProgressLogger
from graphrag.logger.factory import LoggerFactory, LoggerType
from graphrag.utils.cli import redact

# Ignore warnings from numba
warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*")

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


def _logger(logger: ProgressLogger):
def info(msg: str, verbose: bool = False):
log.info(msg)
if verbose:
logger.info(msg)

def error(msg: str, verbose: bool = False):
log.error(msg)
if verbose:
logger.error(msg)

def success(msg: str, verbose: bool = False):
log.info(msg)
if verbose:
logger.success(msg)

return info, error, success


def _register_signal_handlers(logger: ProgressLogger):
def _register_signal_handlers():
import signal

def handle_signal(signum, _):
# Handle the signal here
logger.info(f"Received signal {signum}, exiting...")  # noqa: G004
logger.dispose()
logger.debug(f"Received signal {signum}, exiting...")  # noqa: G004
for task in asyncio.all_tasks():
task.cancel()
logger.info("All tasks cancelled. Exiting...")
logger.debug("All tasks cancelled. Exiting...")

# Register signal handlers for SIGINT and SIGHUP
signal.signal(signal.SIGINT, handle_signal)
|
||||
verbose: bool,
|
||||
memprofile: bool,
|
||||
cache: bool,
|
||||
logger: LoggerType,
|
||||
config_filepath: Path | None,
|
||||
dry_run: bool,
|
||||
skip_validation: bool,
|
||||
@ -87,7 +63,6 @@ def index_cli(
|
||||
verbose=verbose,
|
||||
memprofile=memprofile,
|
||||
cache=cache,
|
||||
logger=logger,
|
||||
dry_run=dry_run,
|
||||
skip_validation=skip_validation,
|
||||
)
|
||||
@ -99,7 +74,6 @@ def update_cli(
|
||||
verbose: bool,
|
||||
memprofile: bool,
|
||||
cache: bool,
|
||||
logger: LoggerType,
|
||||
config_filepath: Path | None,
|
||||
skip_validation: bool,
|
||||
output_dir: Path | None,
|
||||
@ -120,7 +94,6 @@ def update_cli(
|
||||
verbose=verbose,
|
||||
memprofile=memprofile,
|
||||
cache=cache,
|
||||
logger=logger,
|
||||
dry_run=False,
|
||||
skip_validation=skip_validation,
|
||||
)
|
||||
@ -133,39 +106,47 @@ def _run_index(
|
||||
verbose,
|
||||
memprofile,
|
||||
cache,
|
||||
logger,
|
||||
dry_run,
|
||||
skip_validation,
|
||||
):
|
||||
progress_logger = LoggerFactory().create_logger(logger)
|
||||
info, error, success = _logger(progress_logger)
|
||||
# Configure the root logger with the specified log level
|
||||
from graphrag.logger.standard_logging import init_loggers
|
||||
|
||||
# Initialize loggers and reporting config
|
||||
init_loggers(
|
||||
config=config,
|
||||
root_dir=str(config.root_dir) if config.root_dir else None,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
if not cache:
|
||||
config.cache.type = CacheType.none
|
||||
|
||||
enabled_logging, log_path = enable_logging_with_config(config, verbose)
|
||||
if enabled_logging:
|
||||
info(f"Logging enabled at {log_path}", True)
|
||||
# Log the configuration details
|
||||
if config.reporting.type == ReportingType.file:
|
||||
log_dir = Path(config.root_dir or "") / (config.reporting.base_dir or "")
|
||||
log_path = log_dir / "logs.txt"
|
||||
logger.info("Logging enabled at %s", log_path)
|
||||
else:
|
||||
info(
|
||||
f"Logging not enabled for config {redact(config.model_dump())}",
|
||||
True,
|
||||
logger.info(
|
||||
"Logging not enabled for config %s",
|
||||
redact(config.model_dump()),
|
||||
)
|
||||
|
||||
if not skip_validation:
|
||||
validate_config_names(progress_logger, config)
|
||||
validate_config_names(config)
|
||||
|
||||
info(f"Starting pipeline run. {dry_run=}", verbose)
|
||||
info(
|
||||
f"Using default configuration: {redact(config.model_dump())}",
|
||||
verbose,
|
||||
logger.info("Starting pipeline run. %s", dry_run)
|
||||
logger.info(
|
||||
"Using default configuration: %s",
|
||||
redact(config.model_dump()),
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
info("Dry run complete, exiting...", True)
|
||||
logger.info("Dry run complete, exiting...", True)
|
||||
sys.exit(0)
|
||||
|
||||
_register_signal_handlers(progress_logger)
|
||||
_register_signal_handlers()
|
||||
|
||||
outputs = asyncio.run(
|
||||
api.build_index(
|
||||
@ -173,19 +154,17 @@ def _run_index(
|
||||
method=method,
|
||||
is_update_run=is_update_run,
|
||||
memory_profile=memprofile,
|
||||
progress_logger=progress_logger,
|
||||
)
|
||||
)
|
||||
encountered_errors = any(
|
||||
output.errors and len(output.errors) > 0 for output in outputs
|
||||
)
|
||||
|
||||
progress_logger.stop()
|
||||
if encountered_errors:
|
||||
error(
|
||||
"Errors occurred during the pipeline run, see logs for more details.", True
|
||||
logger.error(
|
||||
"Errors occurred during the pipeline run, see logs for more details."
|
||||
)
|
||||
else:
|
||||
success("All workflows completed successfully.", True)
|
||||
logger.info("All workflows completed successfully.")
|
||||
|
||||
sys.exit(1 if encountered_errors else 0)
|
||||
|
||||
@@ -3,10 +3,10 @@

"""CLI implementation of the initialization subcommand."""

import logging
from pathlib import Path

from graphrag.config.init_content import INIT_DOTENV, INIT_YAML
from graphrag.logger.factory import LoggerFactory, LoggerType
from graphrag.prompts.index.community_report import (
COMMUNITY_REPORT_PROMPT,
)
@@ -31,6 +31,8 @@ from graphrag.prompts.query.global_search_reduce_system_prompt import (
from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT

logger = logging.getLogger(__name__)


def initialize_project_at(path: Path, force: bool) -> None:
"""
@@ -48,8 +50,7 @@ def initialize_project_at(path: Path, force: bool) -> None:
ValueError
If the project already exists and force is False.
"""
progress_logger = LoggerFactory().create_logger(LoggerType.RICH)
progress_logger.info(f"Initializing project at {path}")  # noqa: G004
logger.info("Initializing project at %s", path)
root = Path(path)
if not root.exists():
root.mkdir(parents=True, exist_ok=True)

@@ -12,7 +12,6 @@ import typer

from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import IndexingMethod, SearchMethod
from graphrag.logger.types import LoggerType
from graphrag.prompt_tune.defaults import LIMIT, MAX_TOKEN_COUNT, N_SUBSET_MAX, K
from graphrag.prompt_tune.types import DocSelectionType

@@ -157,11 +156,6 @@ def _index_cli(
"--memprofile",
help="Run the indexing pipeline with memory profiling",
),
logger: LoggerType = typer.Option(
LoggerType.RICH.value,
"--logger",
help="The progress logger to use.",
),
dry_run: bool = typer.Option(
False,
"--dry-run",
@@ -201,7 +195,6 @@ def _index_cli(
verbose=verbose,
memprofile=memprofile,
cache=cache,
logger=LoggerType(logger),
config_filepath=config,
dry_run=dry_run,
skip_validation=skip_validation,
@@ -250,11 +243,6 @@ def _update_cli(
"--memprofile",
help="Run the indexing pipeline with memory profiling.",
),
logger: LoggerType = typer.Option(
LoggerType.RICH.value,
"--logger",
help="The progress logger to use.",
),
cache: bool = typer.Option(
True,
"--cache/--no-cache",
@@ -290,7 +278,6 @@ def _update_cli(
verbose=verbose,
memprofile=memprofile,
cache=cache,
logger=LoggerType(logger),
config_filepath=config,
skip_validation=skip_validation,
output_dir=output,
@@ -327,11 +314,6 @@ def _prompt_tune_cli(
"-v",
help="Run the prompt tuning pipeline with verbose logging.",
),
logger: LoggerType = typer.Option(
LoggerType.RICH.value,
"--logger",
help="The progress logger to use.",
),
domain: str | None = typer.Option(
None,
"--domain",
@@ -413,7 +395,6 @@ def _prompt_tune_cli(
config=config,
domain=domain,
verbose=verbose,
logger=logger,
selection_method=selection_method,
limit=limit,
max_tokens=max_tokens,
@@ -453,6 +434,12 @@ def _query_cli(
readable=True,
autocompletion=CONFIG_AUTOCOMPLETE,
),
verbose: bool = typer.Option(
False,
"--verbose",
"-v",
help="Run the query with verbose logging.",
),
data: Path | None = typer.Option(
None,
"--data",
@@ -520,6 +507,7 @@ def _query_cli(
response_type=response_type,
streaming=streaming,
query=query,
verbose=verbose,
)
case SearchMethod.GLOBAL:
run_global_search(
@@ -531,6 +519,7 @@ def _query_cli(
response_type=response_type,
streaming=streaming,
query=query,
verbose=verbose,
)
case SearchMethod.DRIFT:
run_drift_search(
@@ -541,6 +530,7 @@ def _query_cli(
streaming=streaming,
response_type=response_type,
query=query,
verbose=verbose,
)
case SearchMethod.BASIC:
run_basic_search(
@@ -549,6 +539,7 @@ def _query_cli(
root_dir=root,
streaming=streaming,
query=query,
verbose=verbose,
)
case _:
raise ValueError(INVALID_METHOD_ERROR)

@@ -3,13 +3,12 @@

"""CLI implementation of the prompt-tune subcommand."""

import logging
from pathlib import Path

import graphrag.api as api
from graphrag.cli.index import _logger
from graphrag.config.enums import ReportingType
from graphrag.config.load_config import load_config
from graphrag.config.logging import enable_logging_with_config
from graphrag.logger.factory import LoggerFactory, LoggerType
from graphrag.prompt_tune.generator.community_report_summarization import (
COMMUNITY_SUMMARIZATION_FILENAME,
)
@@ -21,13 +20,14 @@ from graphrag.prompt_tune.generator.extract_graph_prompt import (
)
from graphrag.utils.cli import redact

logger = logging.getLogger(__name__)


async def prompt_tune(
root: Path,
config: Path | None,
domain: str | None,
verbose: bool,
logger: LoggerType,
selection_method: api.DocSelectionType,
limit: int,
max_tokens: int,
@@ -47,8 +47,7 @@ async def prompt_tune(
- config: The configuration file.
- root: The root directory.
- domain: The domain to map the input documents to.
- verbose: Whether to enable verbose logging.
- logger: The logger to use.
- verbose: Enable verbose logging.
- selection_method: The chunk selection method.
- limit: The limit of chunks to load.
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
@@ -70,23 +69,29 @@ async def prompt_tune(
if overlap != graph_config.chunks.overlap:
graph_config.chunks.overlap = overlap

progress_logger = LoggerFactory().create_logger(logger)
info, error, success = _logger(progress_logger)
# configure the root logger with the specified log level
from graphrag.logger.standard_logging import init_loggers

enabled_logging, log_path = enable_logging_with_config(
graph_config, verbose, filename="prompt-tune.log"
# initialize loggers with config
init_loggers(
config=graph_config,
root_dir=str(root_path),
verbose=verbose,
)
if enabled_logging:
info(f"Logging enabled at {log_path}", verbose)

# log the configuration details
if graph_config.reporting.type == ReportingType.file:
log_dir = Path(root_path) / (graph_config.reporting.base_dir or "")
log_path = log_dir / "logs.txt"
logger.info("Logging enabled at %s", log_path)
else:
info(
f"Logging not enabled for config {redact(graph_config.model_dump())}",
verbose,
logger.info(
"Logging not enabled for config %s",
redact(graph_config.model_dump()),
)

prompts = await api.generate_indexing_prompts(
config=graph_config,
logger=progress_logger,
chunk_size=chunk_size,
overlap=overlap,
limit=limit,
@@ -102,20 +107,20 @@ async def prompt_tune(

output_path = output.resolve()
if output_path:
info(f"Writing prompts to {output_path}")
logger.info("Writing prompts to %s", output_path)
output_path.mkdir(parents=True, exist_ok=True)
extract_graph_prompt_path = output_path / EXTRACT_GRAPH_FILENAME
entity_summarization_prompt_path = output_path / ENTITY_SUMMARIZATION_FILENAME
community_summarization_prompt_path = (
output_path / COMMUNITY_SUMMARIZATION_FILENAME
)
# Write files to output path
# write files to output path
with extract_graph_prompt_path.open("wb") as file:
file.write(prompts[0].encode(encoding="utf-8", errors="strict"))
with entity_summarization_prompt_path.open("wb") as file:
file.write(prompts[1].encode(encoding="utf-8", errors="strict"))
with community_summarization_prompt_path.open("wb") as file:
file.write(prompts[2].encode(encoding="utf-8", errors="strict"))
success(f"Prompts written to {output_path}")
logger.info("Prompts written to %s", output_path)
else:
error("No output path provided. Skipping writing prompts.")
logger.error("No output path provided. Skipping writing prompts.")

@ -4,6 +4,7 @@
|
||||
"""CLI implementation of the query subcommand."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -12,14 +13,14 @@ import graphrag.api as api
|
||||
from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.logger.print_progress import PrintProgressLogger
|
||||
from graphrag.utils.api import create_storage_from_config
|
||||
from graphrag.utils.storage import load_table_from_storage, storage_has_table
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
logger = PrintProgressLogger("")
|
||||
# Initialize standard logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_global_search(
|
||||
@ -31,6 +32,7 @@ def run_global_search(
|
||||
response_type: str,
|
||||
streaming: bool,
|
||||
query: str,
|
||||
verbose: bool,
|
||||
):
|
||||
     """Perform a global search with a given query.

@@ -59,10 +61,10 @@ def run_global_search(
         final_community_reports_list = dataframe_dict["community_reports"]
         index_names = dataframe_dict["index_names"]

-        logger.success(
-            f"Running Multi-index Global Search: {dataframe_dict['index_names']}"
+        logger.info(
+            "Running multi-index global search on indexes: %s",
+            dataframe_dict["index_names"],
         )

         response, context_data = asyncio.run(
             api.multi_index_global_search(
                 config=config,
@@ -75,9 +77,12 @@ def run_global_search(
                 response_type=response_type,
                 streaming=streaming,
                 query=query,
+                verbose=verbose,
             )
         )
-        logger.success(f"Global Search Response:\n{response}")
+        # log the full response at INFO level for user visibility but at DEBUG level in the API layer
+        logger.info("Query Response:\n%s", response)

         # NOTE: we return the response and context data here purely as a complete demonstration of the API.
         # External users should use the API directly to get the response and context data.
         return response, context_data
@@ -110,6 +115,7 @@ def run_global_search(
             response_type=response_type,
             query=query,
             callbacks=[callbacks],
+            verbose=verbose,
         ):
             full_response += stream_chunk
             print(stream_chunk, end="")  # noqa: T201
@@ -129,9 +135,12 @@ def run_global_search(
             dynamic_community_selection=dynamic_community_selection,
             response_type=response_type,
             query=query,
+            verbose=verbose,
         )
     )
-    logger.success(f"Global Search Response:\n{response}")
+    # log the full response at INFO level for user visibility but at DEBUG level in the API layer
+    logger.info("Global Search Response:\n%s", response)

     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
@@ -145,6 +154,7 @@ def run_local_search(
     response_type: str,
     streaming: bool,
     query: str,
+    verbose: bool,
 ):
     """Perform a local search with a given query.

@@ -178,8 +188,9 @@ def run_local_search(
         final_relationships_list = dataframe_dict["relationships"]
         index_names = dataframe_dict["index_names"]

-        logger.success(
-            f"Running Multi-index Local Search: {dataframe_dict['index_names']}"
+        logger.info(
+            "Running multi-index local search on indexes: %s",
+            dataframe_dict["index_names"],
         )

         # If any covariates tables are missing from any index, set the covariates list to None
@@ -202,9 +213,12 @@ def run_local_search(
                 response_type=response_type,
                 streaming=streaming,
                 query=query,
+                verbose=verbose,
             )
         )
-        logger.success(f"Local Search Response:\n{response}")
+        # log the full response at INFO level for user visibility but at DEBUG level in the API layer
+        logger.info("Local Search Response:\n%s", response)

         # NOTE: we return the response and context data here purely as a complete demonstration of the API.
         # External users should use the API directly to get the response and context data.
         return response, context_data
@@ -242,6 +256,7 @@ def run_local_search(
             response_type=response_type,
             query=query,
             callbacks=[callbacks],
+            verbose=verbose,
         ):
             full_response += stream_chunk
             print(stream_chunk, end="")  # noqa: T201
@@ -263,9 +278,12 @@ def run_local_search(
             community_level=community_level,
             response_type=response_type,
             query=query,
+            verbose=verbose,
         )
     )
-    logger.success(f"Local Search Response:\n{response}")
+    # log the full response at INFO level for user visibility but at DEBUG level in the API layer
+    logger.info("Local Search Response:\n%s", response)

     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
@@ -279,6 +297,7 @@ def run_drift_search(
     response_type: str,
     streaming: bool,
     query: str,
+    verbose: bool,
 ):
     """Perform a local search with a given query.

@@ -310,8 +329,9 @@ def run_drift_search(
         final_relationships_list = dataframe_dict["relationships"]
         index_names = dataframe_dict["index_names"]

-        logger.success(
-            f"Running Multi-index Drift Search: {dataframe_dict['index_names']}"
+        logger.info(
+            "Running multi-index drift search on indexes: %s",
+            dataframe_dict["index_names"],
         )

         response, context_data = asyncio.run(
@@ -327,9 +347,12 @@ def run_drift_search(
                 response_type=response_type,
                 streaming=streaming,
                 query=query,
+                verbose=verbose,
             )
         )
-        logger.success(f"DRIFT Search Response:\n{response}")
+        # log the full response at INFO level for user visibility but at DEBUG level in the API layer
+        logger.info("DRIFT Search Response:\n%s", response)

         # NOTE: we return the response and context data here purely as a complete demonstration of the API.
         # External users should use the API directly to get the response and context data.
         return response, context_data
@@ -365,6 +388,7 @@ def run_drift_search(
             response_type=response_type,
             query=query,
             callbacks=[callbacks],
+            verbose=verbose,
         ):
             full_response += stream_chunk
             print(stream_chunk, end="")  # noqa: T201
@@ -386,9 +410,12 @@ def run_drift_search(
             community_level=community_level,
             response_type=response_type,
             query=query,
+            verbose=verbose,
         )
     )
-    logger.success(f"DRIFT Search Response:\n{response}")
+    # log the full response at INFO level for user visibility but at DEBUG level in the API layer
+    logger.info("DRIFT Search Response:\n%s", response)

     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
@@ -400,6 +427,7 @@ def run_basic_search(
     root_dir: Path,
     streaming: bool,
     query: str,
+    verbose: bool,
 ):
     """Perform a basics search with a given query.

@@ -423,8 +451,9 @@ def run_basic_search(
         final_text_units_list = dataframe_dict["text_units"]
         index_names = dataframe_dict["index_names"]

-        logger.success(
-            f"Running Multi-index Basic Search: {dataframe_dict['index_names']}"
+        logger.info(
+            "Running multi-index basic search on indexes: %s",
+            dataframe_dict["index_names"],
         )

         response, context_data = asyncio.run(
@@ -434,9 +463,12 @@ def run_basic_search(
                 index_names=index_names,
                 streaming=streaming,
                 query=query,
+                verbose=verbose,
             )
         )
-        logger.success(f"Basic Search Response:\n{response}")
+        # log the full response at INFO level for user visibility but at DEBUG level in the API layer
+        logger.info("Basic Search Response:\n%s", response)

         # NOTE: we return the response and context data here purely as a complete demonstration of the API.
         # External users should use the API directly to get the response and context data.
         return response, context_data
@@ -461,6 +493,8 @@ def run_basic_search(
             config=config,
             text_units=final_text_units,
             query=query,
+            callbacks=[callbacks],
+            verbose=verbose,
         ):
             full_response += stream_chunk
             print(stream_chunk, end="")  # noqa: T201
@@ -475,9 +509,12 @@ def run_basic_search(
             config=config,
             text_units=final_text_units,
             query=query,
+            verbose=verbose,
         )
     )
-    logger.success(f"Basic Search Response:\n{response}")
+    # log the full response at INFO level for user visibility but at DEBUG level in the API layer
+    logger.info("Basic Search Response:\n%s", response)

     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
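Throughout these CLI functions the diff swaps loguru-style `logger.success(f"...")` calls for stdlib `logging` calls with lazy %-formatting. A minimal sketch of the difference (logger name and handler are illustrative, not from the repo):

```python
import logging

# Collect formatted records so the behavior is observable without stdout capture.
records: list[str] = []


class ListHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        records.append(self.format(record))


logger = logging.getLogger("graphrag.demo.query")
logger.setLevel(logging.INFO)
handler = ListHandler()
handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
logger.addHandler(handler)

# Lazy %-style formatting: the message is rendered only when a handler emits it,
# unlike an f-string, which is built eagerly even when the level is disabled.
logger.info("Query Response:\n%s", "Answer text")
logger.debug("never rendered: %s", "skipped")  # below INFO, not emitted
```

There is no `success` level in stdlib logging, which is why those calls become `info`.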
@@ -5,7 +5,7 @@

 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Literal
+from typing import ClassVar, Literal

 from graphrag.config.embeddings import default_embeddings
 from graphrag.config.enums import (
@@ -54,7 +54,7 @@ class BasicSearchDefaults:
 class CacheDefaults:
     """Default values for cache."""

-    type = CacheType.file
+    type: ClassVar[CacheType] = CacheType.file
     base_dir: str = "cache"
     connection_string: None = None
     container_name: None = None
@@ -69,7 +69,7 @@ class ChunksDefaults:
     size: int = 1200
     overlap: int = 100
     group_by_columns: list[str] = field(default_factory=lambda: ["id"])
-    strategy = ChunkStrategyType.tokens
+    strategy: ClassVar[ChunkStrategyType] = ChunkStrategyType.tokens
     encoding_model: str = "cl100k_base"
     prepend_metadata: bool = False
     chunk_size_includes_metadata: bool = False
@@ -119,8 +119,8 @@ class DriftSearchDefaults:
     local_search_temperature: float = 0
     local_search_top_p: float = 1
     local_search_n: int = 1
-    local_search_llm_max_gen_tokens = None
-    local_search_llm_max_gen_completion_tokens = None
+    local_search_llm_max_gen_tokens: int | None = None
+    local_search_llm_max_gen_completion_tokens: int | None = None
     chat_model_id: str = DEFAULT_CHAT_MODEL_ID
     embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID

@@ -183,7 +183,9 @@ class ExtractGraphDefaults:
 class TextAnalyzerDefaults:
     """Default values for text analyzer."""

-    extractor_type = NounPhraseExtractorType.RegexEnglish
+    extractor_type: ClassVar[NounPhraseExtractorType] = (
+        NounPhraseExtractorType.RegexEnglish
+    )
     model_name: str = "en_core_web_md"
     max_word_length: int = 15
     word_delimiter: str = " "
@@ -257,7 +259,7 @@ class InputDefaults:
     """Default values for input."""

     storage: InputStorageDefaults = field(default_factory=InputStorageDefaults)
-    file_type = InputFileType.text
+    file_type: ClassVar[InputFileType] = InputFileType.text
     encoding: str = "utf-8"
     file_pattern: str = ""
     file_filter: None = None
@@ -271,7 +273,7 @@ class LanguageModelDefaults:
     """Default values for language model."""

     api_key: None = None
-    auth_type = AuthType.APIKey
+    auth_type: ClassVar[AuthType] = AuthType.APIKey
     encoding_model: str = ""
     max_tokens: int | None = None
     temperature: float = 0
@@ -338,7 +340,7 @@ class PruneGraphDefaults:
 class ReportingDefaults:
     """Default values for reporting."""

-    type = ReportingType.file
+    type: ClassVar[ReportingType] = ReportingType.file
     base_dir: str = "logs"
     connection_string: None = None
     container_name: None = None
@@ -383,7 +385,7 @@ class UpdateIndexOutputDefaults(StorageDefaults):
 class VectorStoreDefaults:
     """Default values for vector stores."""

-    type = VectorStoreType.LanceDB.value
+    type: ClassVar[str] = VectorStoreType.LanceDB.value
     db_uri: str = str(Path(DEFAULT_OUTPUT_BASE_DIR) / "lancedb")
     container_name: str = "default"
     overwrite: bool = True
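The defaults changes above annotate plain class attributes with `ClassVar`. In a dataclass, an un-annotated assignment is silently a class attribute with no declared type, while `ClassVar` documents that intent and keeps the name out of `__init__` and `fields()`. A small sketch (the enum and class names mirror the diff but are redefined here for illustration):

```python
from dataclasses import dataclass, fields
from enum import Enum
from typing import ClassVar


class CacheType(str, Enum):
    file = "file"


@dataclass
class CacheDefaults:
    # ClassVar marks this as a class-level constant: dataclass() skips it,
    # so it never becomes an __init__ parameter or per-instance field.
    type: ClassVar[CacheType] = CacheType.file
    base_dir: str = "cache"


defaults = CacheDefaults()
print(defaults.type.value, defaults.base_dir)
print([f.name for f in fields(CacheDefaults)])
```

Annotating `type: CacheType = CacheType.file` instead would have made it an instance field and an `__init__` parameter, changing the dataclass's public signature; `ClassVar` preserves the old behavior while satisfying type checkers.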
@@ -64,8 +64,6 @@ class ReportingType(str, Enum):

     file = "file"
     """The file reporting configuration type."""
     console = "console"
     """The console reporting configuration type."""
     blob = "blob"
     """The blob reporting configuration type."""

@@ -1,61 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Logging utilities. A unified way for enabling logging."""
-
-import logging
-from pathlib import Path
-
-from graphrag.config.enums import ReportingType
-from graphrag.config.models.graph_rag_config import GraphRagConfig
-
-
-def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
-    """Enable logging to a file.
-
-    Parameters
-    ----------
-    log_filepath : str | Path
-        The path to the log file.
-    verbose : bool, default=False
-        Whether to log debug messages.
-    """
-    log_filepath = Path(log_filepath)
-    log_filepath.parent.mkdir(parents=True, exist_ok=True)
-    log_filepath.touch(exist_ok=True)
-
-    logging.basicConfig(
-        filename=log_filepath,
-        filemode="a",
-        format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
-        datefmt="%H:%M:%S",
-        level=logging.DEBUG if verbose else logging.INFO,
-    )
-
-
-def enable_logging_with_config(
-    config: GraphRagConfig, verbose: bool = False, filename: str = "indexing-engine.log"
-) -> tuple[bool, str]:
-    """Enable logging to a file based on the config.
-
-    Parameters
-    ----------
-    config : GraphRagConfig
-        The configuration.
-    timestamp_value : str
-        The timestamp value representing the directory to place the log files.
-    verbose : bool, default=False
-        Whether to log debug messages.
-
-    Returns
-    -------
-    tuple[bool, str]
-        A tuple of a boolean indicating if logging was enabled and the path to the log file.
-        (False, "") if logging was not enabled.
-        (True, str) if logging was enabled.
-    """
-    if config.reporting.type == ReportingType.file:
-        log_path = Path(config.reporting.base_dir) / filename
-        enable_logging(log_path, verbose)
-        return (True, str(log_path))
-    return (False, "")
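This helper module is deleted outright: the old `enable_logging` reached for `logging.basicConfig`, which mutates the root logger globally. With the move to per-module `logging.getLogger(__name__)` loggers, equivalent file logging can be set up by attaching a handler to the package logger instead. A hedged sketch (paths and the `graphrag.demo` logger name are hypothetical, not the repo's actual wiring):

```python
import logging
import tempfile
from pathlib import Path

# Hypothetical stand-in for the removed enable_logging() helper: a plain
# stdlib FileHandler attached to a package logger instead of basicConfig.
log_path = Path(tempfile.mkdtemp()) / "indexing-engine.log"
log_path.parent.mkdir(parents=True, exist_ok=True)

handler = logging.FileHandler(log_path, mode="a")
handler.setFormatter(
    logging.Formatter(
        "%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )
)
pkg_logger = logging.getLogger("graphrag.demo")
pkg_logger.setLevel(logging.INFO)
pkg_logger.addHandler(handler)

pkg_logger.info("pipeline started")
```

Because child loggers (`graphrag.demo.index`, `graphrag.demo.query`, ...) propagate to this package logger by default, one handler covers the whole package without touching the root logger.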
@@ -9,17 +9,17 @@ from pathlib import Path

 from dotenv import dotenv_values

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 def read_dotenv(root: str) -> None:
     """Read a .env file in the given root path."""
     env_path = Path(root) / ".env"
     if env_path.exists():
-        log.info("Loading pipeline .env file")
+        logger.info("Loading pipeline .env file")
         env_config = dotenv_values(f"{env_path}")
         for key, value in env_config.items():
             if key not in os.environ:
                 os.environ[key] = value or ""
     else:
-        log.info("No .env file found at %s", root)
+        logger.info("No .env file found at %s", root)
@@ -10,19 +10,17 @@ import pandas as pd

 from graphrag.config.models.input_config import InputConfig
 from graphrag.index.input.util import load_files, process_data_columns
-from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 async def load_csv(
     config: InputConfig,
-    progress: ProgressLogger | None,
     storage: PipelineStorage,
 ) -> pd.DataFrame:
     """Load csv inputs from a directory."""
-    log.info("Loading csv files from %s", config.storage.base_dir)
+    logger.info("Loading csv files from %s", config.storage.base_dir)

     async def load_file(path: str, group: dict | None) -> pd.DataFrame:
         if group is None:
@@ -42,4 +40,4 @@ async def load_csv(

         return data

-    return await load_files(load_file, config, storage, progress)
+    return await load_files(load_file, config, storage)
@@ -14,11 +14,9 @@ from graphrag.config.models.input_config import InputConfig
 from graphrag.index.input.csv import load_csv
 from graphrag.index.input.json import load_json
 from graphrag.index.input.text import load_text
-from graphrag.logger.base import ProgressLogger
-from graphrag.logger.null_progress import NullProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)
 loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
     InputFileType.text: load_text,
     InputFileType.csv: load_csv,
@@ -29,17 +27,14 @@ loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
 async def create_input(
     config: InputConfig,
     storage: PipelineStorage,
-    progress_reporter: ProgressLogger | None = None,
 ) -> pd.DataFrame:
     """Instantiate input data for a pipeline."""
-    progress_reporter = progress_reporter or NullProgressLogger()
+    logger.info("loading input from root_dir=%s", config.storage.base_dir)

     if config.file_type in loaders:
-        progress = progress_reporter.child(
-            f"Loading Input ({config.file_type})", transient=False
-        )
+        logger.info("Loading Input %s", config.file_type)
         loader = loaders[config.file_type]
-        result = await loader(config, progress, storage)
+        result = await loader(config, storage)
         # Convert metadata columns to strings and collapse them into a JSON object
         if config.metadata:
             if all(col in result.columns for col in config.metadata):
@@ -10,19 +10,17 @@ import pandas as pd

 from graphrag.config.models.input_config import InputConfig
 from graphrag.index.input.util import load_files, process_data_columns
-from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 async def load_json(
     config: InputConfig,
-    progress: ProgressLogger | None,
     storage: PipelineStorage,
 ) -> pd.DataFrame:
     """Load json inputs from a directory."""
-    log.info("Loading json files from %s", config.storage.base_dir)
+    logger.info("Loading json files from %s", config.storage.base_dir)

     async def load_file(path: str, group: dict | None) -> pd.DataFrame:
         if group is None:
@@ -46,4 +44,4 @@ async def load_json(

         return data

-    return await load_files(load_file, config, storage, progress)
+    return await load_files(load_file, config, storage)
@@ -11,15 +11,13 @@ import pandas as pd
 from graphrag.config.models.input_config import InputConfig
 from graphrag.index.input.util import load_files
 from graphrag.index.utils.hashing import gen_sha512_hash
-from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 async def load_text(
     config: InputConfig,
-    progress: ProgressLogger | None,
     storage: PipelineStorage,
 ) -> pd.DataFrame:
     """Load text inputs from a directory."""
@@ -34,4 +32,4 @@ async def load_text(
         new_item["creation_date"] = await storage.get_creation_date(path)
         return pd.DataFrame([new_item])

-    return await load_files(load_file, config, storage, progress)
+    return await load_files(load_file, config, storage)
@@ -11,23 +11,20 @@ import pandas as pd

 from graphrag.config.models.input_config import InputConfig
 from graphrag.index.utils.hashing import gen_sha512_hash
-from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 async def load_files(
     loader: Any,
     config: InputConfig,
     storage: PipelineStorage,
-    progress: ProgressLogger | None,
 ) -> pd.DataFrame:
     """Load files from storage and apply a loader function."""
     files = list(
         storage.find(
             re.compile(config.file_pattern),
-            progress=progress,
             file_filter=config.file_filter,
         )
     )
@@ -42,17 +39,17 @@ async def load_files(
         try:
             files_loaded.append(await loader(file, group))
         except Exception as e:  # noqa: BLE001 (catching Exception is fine here)
-            log.warning("Warning! Error loading file %s. Skipping...", file)
-            log.warning("Error: %s", e)
+            logger.warning("Warning! Error loading file %s. Skipping...", file)
+            logger.warning("Error: %s", e)

-    log.info(
+    logger.info(
         "Found %d %s files, loading %d", len(files), config.file_type, len(files_loaded)
     )
     result = pd.concat(files_loaded)
     total_files_log = (
         f"Total number of unfiltered {config.file_type} rows: {len(result)}"
     )
-    log.info(total_files_log)
+    logger.info(total_files_log)
     return result
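The `load_files` loop above logs and skips files whose loader raises, then reports found-vs-loaded counts. The shape of that error handling, stripped of the storage layer (function names and the synchronous loader are simplifications for illustration):

```python
import logging

logger = logging.getLogger("load_files_demo")


def load_all(paths: list[str], loader) -> list[str]:
    """Apply loader to each path, logging and skipping failures instead of aborting."""
    loaded = []
    for path in paths:
        try:
            loaded.append(loader(path))
        except Exception as e:  # noqa: BLE001 - deliberately broad: skip bad files
            logger.warning("Warning! Error loading file %s. Skipping...", path)
            logger.warning("Error: %s", e)
    logger.info("Found %d files, loading %d", len(paths), len(loaded))
    return loaded


def demo_loader(path: str) -> str:
    if path.endswith(".bad"):
        raise ValueError("unparsable")
    return path.upper()


print(load_all(["a.csv", "b.bad", "c.csv"], demo_loader))
```

One pipeline-relevant caveat: if every file fails, `pd.concat(files_loaded)` in the real function receives an empty list and raises, so the skip-and-continue behavior only softens partial failures.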
@@ -66,7 +63,7 @@ def process_data_columns(
     )
     if config.text_column is not None and "text" not in documents.columns:
         if config.text_column not in documents.columns:
-            log.warning(
+            logger.warning(
                 "text_column %s not found in csv file %s",
                 config.text_column,
                 path,
@@ -75,7 +72,7 @@ def process_data_columns(
         documents["text"] = documents.apply(lambda x: x[config.text_column], axis=1)
     if config.title_column is not None:
         if config.title_column not in documents.columns:
-            log.warning(
+            logger.warning(
                 "title_column %s not found in csv file %s",
                 config.title_column,
                 path,
@@ -65,6 +65,7 @@ async def _extract_nodes(
         extract,
         num_threads=num_threads,
         async_type=AsyncType.Threaded,
+        progress_msg="extract noun phrases progress: ",
     )

     noun_node_df = text_unit_df.explode("noun_phrases")
@@ -8,7 +8,7 @@ from abc import ABCMeta, abstractmethod

 import spacy

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 class BaseNounPhraseExtractor(metaclass=ABCMeta):
@@ -54,7 +54,7 @@ class BaseNounPhraseExtractor(metaclass=ABCMeta):
             return spacy.load(model_name, exclude=exclude)
         except OSError:
             msg = f"Model `{model_name}` not found. Attempting to download..."
-            log.info(msg)
+            logger.info(msg)
             from spacy.cli.download import download

             download(model_name)
@@ -13,7 +13,7 @@ from graphrag.index.utils.stable_lcc import stable_largest_connected_component
 Communities = list[tuple[int, int, int, list[str]]]


-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 def cluster_graph(
@@ -24,7 +24,7 @@ def cluster_graph(
 ) -> Communities:
     """Apply a hierarchical clustering algorithm to a graph."""
     if len(graph.nodes) == 0:
-        log.warning("Graph has no nodes")
+        logger.warning("Graph has no nodes")
         return []

     node_id_to_community_map, parent_mapping = _compute_leiden_communities(
@@ -17,7 +17,7 @@ from graphrag.index.operations.embed_text.strategies.typing import TextEmbedding
 from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
 from graphrag.vector_stores.factory import VectorStoreFactory

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)

 # Per Azure OpenAI Limits
 # https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
@@ -141,7 +141,14 @@ async def _text_embed_with_vector_store(

     all_results = []

+    num_total_batches = (input.shape[0] + insert_batch_size - 1) // insert_batch_size
     while insert_batch_size * i < input.shape[0]:
+        logger.info(
+            "uploading text embeddings batch %d/%d of size %d to vector store",
+            i + 1,
+            num_total_batches,
+            insert_batch_size,
+        )
         batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
         texts: list[str] = batch[embed_column].to_numpy().tolist()
         titles: list[str] = batch[title].to_numpy().tolist()
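The new `num_total_batches` line uses integer ceil division, `(n + b - 1) // b`, so the "batch %d/%d" log can show the total up front. The batching arithmetic in isolation (the helper name is ours, not the repo's):

```python
from collections.abc import Iterator


def batch_bounds(total_rows: int, insert_batch_size: int) -> Iterator[tuple[int, int, int]]:
    """Yield (batch_number, start, end) slices covering total_rows rows."""
    # Ceil division without floats: (n + b - 1) // b
    num_total_batches = (total_rows + insert_batch_size - 1) // insert_batch_size
    for i in range(num_total_batches):
        start = insert_batch_size * i
        end = min(start + insert_batch_size, total_rows)
        yield i + 1, start, end


print(list(batch_bounds(10, 4)))  # → [(1, 0, 4), (2, 4, 8), (3, 8, 10)]
```

Note the last batch can be smaller than `insert_batch_size`; the log message above reports the configured batch size, not the actual size of the final slice.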
@@ -195,7 +202,7 @@ def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
     collection_name = create_collection_name(container_name, embedding_name)

     msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
-    log.info(msg)
+    logger.info(msg)
     return collection_name

@@ -21,7 +21,9 @@ async def run(  # noqa RUF029 async is required for interface
 ) -> TextEmbeddingResult:
     """Run the Claim extraction chain."""
     input = input if isinstance(input, Iterable) else [input]
-    ticker = progress_ticker(callbacks.progress, len(input))
+    ticker = progress_ticker(
+        callbacks.progress, len(input), description="generate embeddings progress: "
+    )
     return TextEmbeddingResult(
         embeddings=[_embed_text(cache, text, ticker) for text in input]
     )
@@ -19,7 +19,7 @@ from graphrag.language_model.manager import ModelManager
 from graphrag.language_model.protocol.base import EmbeddingModel
 from graphrag.logger.progress import ProgressTicker, progress_ticker

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 async def run(
@@ -54,7 +54,7 @@ async def run(
         batch_max_tokens,
         splitter,
     )
-    log.info(
+    logger.info(
         "embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, batch_max_tokens=%d",
         len(input),
         len(texts),
@@ -62,7 +62,11 @@ async def run(
         batch_size,
         batch_max_tokens,
     )
-    ticker = progress_ticker(callbacks.progress, len(text_batches))
+    ticker = progress_ticker(
+        callbacks.progress,
+        len(text_batches),
+        description="generate embeddings progress: ",
+    )

     # Embed each chunk of snippets
     embeddings = await _execute(model, text_batches, ticker, semaphore)
@@ -20,7 +20,7 @@ from graphrag.prompts.index.extract_claims import (
 DEFAULT_TUPLE_DELIMITER = "<|>"
 DEFAULT_RECORD_DELIMITER = "##"
 DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 @dataclass
@@ -119,7 +119,7 @@ class ClaimExtractor:
                 ]
                 source_doc_map[document_id] = text
             except Exception as e:
-                log.exception("error extracting claim")
+                logger.exception("error extracting claim")
                 self._on_error(
                     e,
                     traceback.format_exc(),
@@ -23,7 +23,7 @@ from graphrag.index.operations.extract_covariates.typing import (
 from graphrag.index.utils.derive_from_rows import derive_from_rows
 from graphrag.language_model.manager import ModelManager

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
@@ -41,7 +41,7 @@ async def extract_covariates(
     num_threads: int = 4,
 ):
     """Extract claims from a piece of text."""
-    log.debug("extract_covariates strategy=%s", strategy)
+    logger.debug("extract_covariates strategy=%s", strategy)
     if entity_types is None:
         entity_types = DEFAULT_ENTITY_TYPES

@@ -71,6 +71,7 @@ async def extract_covariates(
         callbacks,
         async_type=async_mode,
         num_threads=num_threads,
+        progress_msg="extract covariates progress: ",
     )
     return pd.DataFrame([item for row in results for item in row or []])

@@ -110,8 +111,8 @@ async def run_extract_claims(
         model_invoker=llm,
         extraction_prompt=extraction_prompt,
         max_gleanings=max_gleanings,
-        on_error=lambda e, s, d: (
-            callbacks.error("Claim Extraction Error", e, s, d) if callbacks else None
+        on_error=lambda e, s, d: logger.error(
+            "Claim Extraction Error", exc_info=e, extra={"stack": s, "details": d}
         ),
     )
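The new `on_error` handler routes failures through `logger.error` with `exc_info` (attaches the exception and traceback to the record) and `extra` (adds arbitrary attributes such as the stack string and details dict). A minimal sketch of what a handler sees (logger name and the captured values are illustrative):

```python
import logging

captured: list[logging.LogRecord] = []


class CaptureHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        captured.append(record)


logger = logging.getLogger("extract_claims_demo")
logger.setLevel(logging.ERROR)
logger.addHandler(CaptureHandler())

try:
    raise ValueError("bad claim")
except ValueError as err:
    # exc_info accepts the exception instance; logging expands it to a
    # (type, value, traceback) triple on the record. extra becomes
    # plain attributes on the LogRecord (record.stack, record.details).
    logger.error(
        "Claim Extraction Error", exc_info=err, extra={"stack": "...", "details": {}}
    )
```

One constraint worth knowing: `extra` keys must not collide with built-in `LogRecord` attribute names (e.g. `message`, `asctime`), or logging raises a `KeyError`.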
@@ -18,7 +18,7 @@ from graphrag.index.operations.extract_graph.typing import (
 )
 from graphrag.index.utils.derive_from_rows import derive_from_rows

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
@@ -36,7 +36,7 @@ async def extract_graph(
     num_threads: int = 4,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Extract a graph from a piece of text using a language model."""
-    log.debug("entity_extract strategy=%s", strategy)
+    logger.debug("entity_extract strategy=%s", strategy)
     if entity_types is None:
         entity_types = DEFAULT_ENTITY_TYPES
     strategy = strategy or {}
@@ -54,7 +54,6 @@ async def extract_graph(
         result = await strategy_exec(
             [Document(text=text, id=id)],
             entity_types,
-            callbacks,
             cache,
             strategy_config,
         )
@@ -67,6 +66,7 @@ async def extract_graph(
         callbacks,
         async_type=async_mode,
         num_threads=num_threads,
+        progress_msg="extract graph progress: ",
     )

     entity_dfs = []
@@ -27,7 +27,7 @@ DEFAULT_RECORD_DELIMITER = "##"
 DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
 DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 @dataclass
@@ -119,7 +119,7 @@ class GraphExtractor:
                 source_doc_map[doc_index] = text
                 all_records[doc_index] = result
             except Exception as e:
-                log.exception("error extracting graph")
+                logger.exception("error extracting graph")
                 self._on_error(
                     e,
                     traceback.format_exc(),
@@ -3,10 +3,11 @@

 """A module containing run_graph_intelligence, run_extract_graph and _create_text_splitter methods to run graph intelligence."""

+import logging
+
 import networkx as nx

 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.defaults import graphrag_config_defaults
 from graphrag.config.models.language_model_config import LanguageModelConfig
 from graphrag.index.operations.extract_graph.graph_extractor import GraphExtractor
@@ -19,11 +20,12 @@ from graphrag.index.operations.extract_graph.typing import (
 from graphrag.language_model.manager import ModelManager
 from graphrag.language_model.protocol.base import ChatModel

+logger = logging.getLogger(__name__)
+
+
 async def run_graph_intelligence(
     docs: list[Document],
     entity_types: EntityTypes,
     callbacks: WorkflowCallbacks,
     cache: PipelineCache,
     args: StrategyConfig,
 ) -> EntityExtractionResult:
@@ -34,18 +36,16 @@ async def run_graph_intelligence(
         name="extract_graph",
         model_type=llm_config.type,
         config=llm_config,
         callbacks=callbacks,
         cache=cache,
     )

-    return await run_extract_graph(llm, docs, entity_types, callbacks, args)
+    return await run_extract_graph(llm, docs, entity_types, args)


 async def run_extract_graph(
     model: ChatModel,
     docs: list[Document],
     entity_types: EntityTypes,
-    callbacks: WorkflowCallbacks | None,
     args: StrategyConfig,
 ) -> EntityExtractionResult:
     """Run the entity extraction chain."""
@@ -61,8 +61,8 @@ async def run_extract_graph(
         model_invoker=model,
         prompt=extraction_prompt,
         max_gleanings=max_gleanings,
-        on_error=lambda e, s, d: (
-            callbacks.error("Entity Extraction Error", e, s, d) if callbacks else None
+        on_error=lambda e, s, d: logger.error(
+            "Entity Extraction Error", exc_info=e, extra={"stack": s, "details": d}
         ),
     )
     text_list = [doc.text.strip() for doc in docs]
@@ -11,7 +11,6 @@ from typing import Any

 import networkx as nx

 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks

 ExtractedEntity = dict[str, Any]
 ExtractedRelationship = dict[str, Any]
@@ -40,7 +39,6 @@ EntityExtractStrategy = Callable[
     [
         list[Document],
         EntityTypes,
-        WorkflowCallbacks,
         PipelineCache,
         StrategyConfig,
     ],
@@ -7,7 +7,6 @@ from uuid import uuid4

 import pandas as pd

-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.embed_graph_config import EmbedGraphConfig
 from graphrag.data_model.schemas import ENTITIES_FINAL_COLUMNS
 from graphrag.index.operations.compute_degree import compute_degree
@@ -19,7 +18,6 @@ from graphrag.index.operations.layout_graph.layout_graph import layout_graph
 def finalize_entities(
     entities: pd.DataFrame,
     relationships: pd.DataFrame,
-    callbacks: WorkflowCallbacks,
     embed_config: EmbedGraphConfig | None = None,
     layout_enabled: bool = False,
 ) -> pd.DataFrame:
@@ -33,7 +31,6 @@ def finalize_entities(
     )
     layout = layout_graph(
         graph,
-        callbacks,
         layout_enabled,
         embeddings=graph_embeddings,
     )
@@ -3,17 +3,19 @@

"""A module containing layout_graph, _run_layout and _apply_layout_to_graph methods definition."""

import logging

import networkx as nx
import pandas as pd

from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
from graphrag.index.operations.layout_graph.typing import GraphLayout

logger = logging.getLogger(__name__)


def layout_graph(
graph: nx.Graph,
callbacks: WorkflowCallbacks,
enabled: bool,
embeddings: NodeEmbeddings | None,
):
@@ -44,7 +46,6 @@ def layout_graph(
graph,
enabled,
embeddings if embeddings is not None else {},
callbacks,
)

layout_df = pd.DataFrame(layout)
@@ -58,7 +59,6 @@ def _run_layout(
graph: nx.Graph,
enabled: bool,
embeddings: NodeEmbeddings,
callbacks: WorkflowCallbacks,
) -> GraphLayout:
if enabled:
from graphrag.index.operations.layout_graph.umap import (
@@ -68,7 +68,9 @@ def _run_layout(
return run_umap(
graph,
embeddings,
lambda e, stack, d: callbacks.error("Error in Umap", e, stack, d),
lambda e, stack, d: logger.error(
"Error in Umap", exc_info=e, extra={"stack": stack, "details": d}
),
)
from graphrag.index.operations.layout_graph.zero import (
run as run_zero,
@@ -76,5 +78,7 @@ def _run_layout(

return run_zero(
graph,
lambda e, stack, d: callbacks.error("Error in Zero", e, stack, d),
lambda e, stack, d: logger.error(
"Error in Zero", exc_info=e, extra={"stack": stack, "details": d}
),
)

@@ -20,7 +20,7 @@ from graphrag.index.typing.error_handler import ErrorHandlerFn
# for "size" or "cluster"
# We could also have a boolean to indicate to use node sizes or clusters

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


def run(
@@ -56,7 +56,7 @@ def run(
**additional_args,
)
except Exception as e:
log.exception("Error running UMAP")
logger.exception("Error running UMAP")
on_error(e, traceback.format_exc(), None)
# Umap may fail due to input sparseness or memory pressure.
# For now, in these cases, we'll just return a layout with all nodes at (0, 0)

@@ -18,7 +18,7 @@ from graphrag.index.typing.error_handler import ErrorHandlerFn
# for "size" or "cluster"
# We could also have a boolean to indicate to use node sizes or clusters

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


def run(
@@ -47,7 +47,7 @@ def run(
try:
return get_zero_positions(node_labels=nodes, **additional_args)
except Exception as e:
log.exception("Error running zero-position")
logger.exception("Error running zero-position")
on_error(e, traceback.format_exc(), None)
# Umap may fail due to input sparseness or memory pressure.
# For now, in these cases, we'll just return a layout with all nodes at (0, 0)

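In the umap and zero runners above, `log.exception` becomes `logger.exception` after the module-logger rename. A small sketch (the logger name and failure are invented for illustration) of why `exception` fits these `except` blocks: it logs at ERROR level and appends the active traceback automatically, so no explicit `traceback.format_exc()` is needed for the log record itself.

```python
import io
import logging

# Route root logging to an in-memory buffer so the record is inspectable.
buf = io.StringIO()
logging.basicConfig(stream=buf, level=logging.ERROR, force=True)
demo = logging.getLogger("umap_demo")  # hypothetical name

def run_layout() -> list[tuple[float, float]]:
    try:
        raise MemoryError("simulated UMAP failure")  # stand-in for a real error
    except Exception:
        # Inside an except block, exception() records message + traceback.
        demo.exception("Error running UMAP")
        # Mirrors the diff's fallback comment: all nodes at (0, 0).
        return [(0.0, 0.0)]

positions = run_layout()
log_text = buf.getvalue()
```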
@@ -13,7 +13,7 @@ from graphrag.index.typing.error_handler import ErrorHandlerFn
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)

# these tokens are used in the prompt
INPUT_TEXT_KEY = "input_text"
@@ -86,7 +86,7 @@ class CommunityReportsExtractor:

output = response.parsed_response
except Exception as e:
log.exception("error generating community report")
logger.exception("error generating community report")
self._on_error(e, traceback.format_exc(), None)

text_output = self._get_text_output(output) if output else ""

@@ -32,7 +32,7 @@ from graphrag.index.utils.dataframes import (
from graphrag.logger.progress import progress_iterable
from graphrag.query.llm.text_utils import num_tokens

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


def build_local_context(
@@ -69,7 +69,7 @@ def _prepare_reports_at_level(
"""Prepare reports at a given level."""
# Filter and prepare node details
level_node_df = node_df[node_df[schemas.COMMUNITY_LEVEL] == level]
log.info("Number of nodes at level=%s => %s", level, len(level_node_df))
logger.info("Number of nodes at level=%s => %s", level, len(level_node_df))
nodes_set = set(level_node_df[schemas.TITLE])

# Filter and prepare edge details

@@ -4,7 +4,6 @@
"""A module containing run, _run_extractor and _load_nodes_edges_for_claim_chain methods definition."""

import logging
import traceback

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -21,7 +20,7 @@ from graphrag.index.utils.rate_limiter import RateLimiter
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import ChatModel

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


async def run_graph_intelligence(
@@ -42,7 +41,7 @@ async def run_graph_intelligence(
cache=cache,
)

return await _run_extractor(llm, community, input, level, args, callbacks)
return await _run_extractor(llm, community, input, level, args)


async def _run_extractor(
@@ -51,7 +50,6 @@ async def _run_extractor(
input: str,
level: int,
args: StrategyConfig,
callbacks: WorkflowCallbacks,
) -> CommunityReport | None:
# RateLimiter
rate_limiter = RateLimiter(rate=1, per=60)
@@ -59,8 +57,8 @@ async def _run_extractor(
model,
extraction_prompt=args.get("extraction_prompt", None),
max_report_length=args.get("max_report_length", None),
on_error=lambda e, stack, _data: callbacks.error(
"Community Report Extraction Error", e, stack
on_error=lambda e, stack, _data: logger.error(
"Community Report Extraction Error", exc_info=e, extra={"stack": stack}
),
)

@@ -69,7 +67,7 @@ async def _run_extractor(
results = await extractor(input)
report = results.structured_output
if report is None:
log.warning("No report found for community: %s", community)
logger.warning("No report found for community: %s", community)
return None

return CommunityReport(
@@ -86,7 +84,6 @@ async def _run_extractor(
],
full_content_json=report.model_dump_json(indent=4),
)
except Exception as e:
log.exception("Error processing community: %s", community)
callbacks.error("Community Report Extraction Error", e, traceback.format_exc())
except Exception:
logger.exception("Error processing community: %s", community)
return None

@@ -24,7 +24,7 @@ from graphrag.index.operations.summarize_communities.utils import (
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.logger.progress import progress_ticker

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


async def summarize_communities(
@@ -64,7 +64,7 @@ async def summarize_communities(
)
level_contexts.append(level_context)

for level_context in level_contexts:
for i, level_context in enumerate(level_contexts):

async def run_generate(record):
result = await _generate_report(
@@ -85,6 +85,7 @@ async def summarize_communities(
callbacks=NoopWorkflowCallbacks(),
num_threads=num_threads,
async_type=async_mode,
progress_msg=f"level {levels[i]} summarize communities progress: ",
)
reports.extend([lr for lr in local_reports if lr is not None])


@@ -20,7 +20,7 @@ from graphrag.index.operations.summarize_communities.text_unit_context.sort_cont
)
from graphrag.query.llm.text_utils import num_tokens

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


def build_local_context(

@@ -9,7 +9,7 @@ import pandas as pd

import graphrag.data_model.schemas as schemas

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


def prep_text_units(

@@ -10,7 +10,7 @@ import pandas as pd
import graphrag.data_model.schemas as schemas
from graphrag.query.llm.text_utils import num_tokens

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


def get_context_string(

@@ -3,8 +3,9 @@

"""A module containing run_graph_intelligence, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence."""

import logging

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.operations.summarize_descriptions.description_summary_extractor import (
SummarizeExtractor,
@@ -16,11 +17,12 @@ from graphrag.index.operations.summarize_descriptions.typing import (
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import ChatModel

logger = logging.getLogger(__name__)


async def run_graph_intelligence(
id: str | tuple[str, str],
descriptions: list[str],
callbacks: WorkflowCallbacks,
cache: PipelineCache,
args: StrategyConfig,
) -> SummarizedDescriptionResult:
@@ -30,18 +32,16 @@ async def run_graph_intelligence(
name="summarize_descriptions",
model_type=llm_config.type,
config=llm_config,
callbacks=callbacks,
cache=cache,
)

return await run_summarize_descriptions(llm, id, descriptions, callbacks, args)
return await run_summarize_descriptions(llm, id, descriptions, args)


async def run_summarize_descriptions(
model: ChatModel,
id: str | tuple[str, str],
descriptions: list[str],
callbacks: WorkflowCallbacks,
args: StrategyConfig,
) -> SummarizedDescriptionResult:
"""Run the entity extraction chain."""
@@ -52,10 +52,10 @@ async def run_summarize_descriptions(
extractor = SummarizeExtractor(
model_invoker=model,
summarization_prompt=summarize_prompt,
on_error=lambda e, stack, details: (
callbacks.error("Entity Extraction Error", e, stack, details)
if callbacks
else None
on_error=lambda e, stack, details: logger.error(
"Entity Extraction Error",
exc_info=e,
extra={"stack": stack, "details": details},
),
max_summary_length=max_summary_length,
max_input_tokens=max_input_tokens,

@@ -17,7 +17,7 @@ from graphrag.index.operations.summarize_descriptions.typing import (
)
from graphrag.logger.progress import ProgressTicker, progress_ticker

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


async def summarize_descriptions(
@@ -29,7 +29,7 @@ async def summarize_descriptions(
num_threads: int = 4,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Summarize entity and relationship descriptions from an entity graph, using a language model."""
log.debug("summarize_descriptions strategy=%s", strategy)
logger.debug("summarize_descriptions strategy=%s", strategy)
strategy = strategy or {}
strategy_exec = load_strategy(
strategy.get("type", SummarizeStrategyType.graph_intelligence)
@@ -41,7 +41,11 @@ async def summarize_descriptions(
):
ticker_length = len(nodes) + len(edges)

ticker = progress_ticker(callbacks.progress, ticker_length)
ticker = progress_ticker(
callbacks.progress,
ticker_length,
description="Summarize entity/relationship description progress: ",
)

node_futures = [
do_summarize_descriptions(
@@ -95,9 +99,7 @@ async def summarize_descriptions(
semaphore: asyncio.Semaphore,
):
async with semaphore:
results = await strategy_exec(
id, descriptions, callbacks, cache, strategy_config
)
results = await strategy_exec(id, descriptions, cache, strategy_config)
ticker(1)
return results


@@ -9,7 +9,6 @@ from enum import Enum
from typing import Any, NamedTuple

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks

StrategyConfig = dict[str, Any]

@@ -26,7 +25,6 @@ SummarizationStrategy = Callable[
[
str | tuple[str, str],
list[str],
WorkflowCallbacks,
PipelineCache,
StrategyConfig,
],

@@ -7,7 +7,6 @@ import json
import logging
import re
import time
import traceback
from collections.abc import AsyncIterable
from dataclasses import asdict

@@ -17,20 +16,17 @@ from graphrag.index.run.utils import create_run_context
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.pipeline import Pipeline
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.logger.base import ProgressLogger
from graphrag.logger.progress import Progress
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.api import create_cache_from_config, create_storage_from_config
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


async def run_pipeline(
pipeline: Pipeline,
config: GraphRagConfig,
callbacks: WorkflowCallbacks,
logger: ProgressLogger,
is_update_run: bool = False,
) -> AsyncIterable[PipelineRunResult]:
"""Run all workflows using a simplified pipeline."""
@@ -66,7 +62,6 @@ async def run_pipeline(
cache=cache,
callbacks=callbacks,
state=state,
progress_logger=logger,
)

else:
@@ -78,13 +73,11 @@ async def run_pipeline(
cache=cache,
callbacks=callbacks,
state=state,
progress_logger=logger,
)

async for table in _run_pipeline(
pipeline=pipeline,
config=config,
logger=logger,
context=context,
):
yield table
@@ -93,7 +86,6 @@ async def run_pipeline(
async def _run_pipeline(
pipeline: Pipeline,
config: GraphRagConfig,
logger: ProgressLogger,
context: PipelineRunContext,
) -> AsyncIterable[PipelineRunResult]:
start_time = time.time()
@@ -103,13 +95,12 @@ async def _run_pipeline(
try:
await _dump_json(context)

logger.info("Executing pipeline...")
for name, workflow_function in pipeline.run():
last_workflow = name
progress = logger.child(name, transient=False)
context.callbacks.workflow_start(name, None)
work_time = time.time()
result = await workflow_function(config, context)
progress(Progress(percent=1))
context.callbacks.workflow_end(name, result)
yield PipelineRunResult(
workflow=name, result=result.result, state=context.state, errors=None
@@ -120,11 +111,11 @@ async def _run_pipeline(
break

context.stats.total_runtime = time.time() - start_time
logger.info("Indexing pipeline complete.")
await _dump_json(context)

except Exception as e:
log.exception("error running workflow %s", last_workflow)
context.callbacks.error("Error running pipeline!", e, traceback.format_exc())
logger.exception("error running workflow %s", last_workflow)
yield PipelineRunResult(
workflow=last_workflow, result=None, state=context.state, errors=[e]
)

@@ -6,15 +6,12 @@
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.callbacks.workflow_callbacks_manager import WorkflowCallbacksManager
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.state import PipelineState
from graphrag.index.typing.stats import PipelineRunStats
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.api import create_storage_from_config
@@ -26,7 +23,6 @@ def create_run_context(
previous_storage: PipelineStorage | None = None,
cache: PipelineCache | None = None,
callbacks: WorkflowCallbacks | None = None,
progress_logger: ProgressLogger | None = None,
stats: PipelineRunStats | None = None,
state: PipelineState | None = None,
) -> PipelineRunContext:
@@ -37,21 +33,18 @@ def create_run_context(
previous_storage=previous_storage or MemoryPipelineStorage(),
cache=cache or InMemoryCache(),
callbacks=callbacks or NoopWorkflowCallbacks(),
progress_logger=progress_logger or NullProgressLogger(),
stats=stats or PipelineRunStats(),
state=state or {},
)


def create_callback_chain(
callbacks: list[WorkflowCallbacks] | None, progress: ProgressLogger | None
callbacks: list[WorkflowCallbacks] | None,
) -> WorkflowCallbacks:
"""Create a callback manager that encompasses multiple callbacks."""
manager = WorkflowCallbacksManager()
for callback in callbacks or []:
manager.register(callback)
if progress is not None:
manager.register(ProgressWorkflowCallbacks(progress))
return manager



@@ -21,7 +21,7 @@ DecodeFn = Callable[[EncodedText], str]
EncodeFn = Callable[[str], EncodedText]
LengthFn = Callable[[str], int]

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
@@ -100,7 +100,9 @@ class TokenTextSplitter(TextSplitter):
try:
enc = tiktoken.encoding_for_model(model_name)
except KeyError:
log.exception("Model %s not found, using %s", model_name, encoding_name)
logger.exception(
"Model %s not found, using %s", model_name, encoding_name
)
enc = tiktoken.get_encoding(encoding_name)
else:
enc = tiktoken.get_encoding(encoding_name)

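The repeated `log` to `logger` renames across these files standardize on module-level `logging.getLogger(__name__)`. A short sketch of why that convention pays off (the dotted names below assume graphrag's package layout purely for illustration): `__name__` yields hierarchical logger names, so one configuration on the package root covers every module.

```python
import logging

# Module-level loggers named after dotted module paths form a tree.
root = logging.getLogger("graphrag")
child = logging.getLogger("graphrag.index.text_splitting")

# Configure only the package root...
root.setLevel(logging.WARNING)

# ...and children inherit the effective level from the nearest configured
# ancestor; unconfigured intermediate names are skipped when resolving parents.
effective = child.getEffectiveLevel()
parent_name = child.parent.name
```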
@@ -10,7 +10,6 @@ from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.typing.state import PipelineState
from graphrag.index.typing.stats import PipelineRunStats
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage


@@ -29,7 +28,5 @@ class PipelineRunContext:
"Cache instance for reading previous LLM responses."
callbacks: WorkflowCallbacks
"Callbacks to be called during the pipeline run."
progress_logger: ProgressLogger
"Progress logger for the pipeline run."
state: PipelineState
"Arbitrary property bag for runtime state, persistent pre-computes, or experimental features."

@@ -37,17 +37,18 @@ async def derive_from_rows(
callbacks: WorkflowCallbacks | None = None,
num_threads: int = 4,
async_type: AsyncType = AsyncType.AsyncIO,
progress_msg: str = "",
) -> list[ItemType | None]:
"""Apply a generic transform function to each row. Any errors will be reported and thrown."""
callbacks = callbacks or NoopWorkflowCallbacks()
match async_type:
case AsyncType.AsyncIO:
return await derive_from_rows_asyncio(
input, transform, callbacks, num_threads
input, transform, callbacks, num_threads, progress_msg
)
case AsyncType.Threaded:
return await derive_from_rows_asyncio_threads(
input, transform, callbacks, num_threads
input, transform, callbacks, num_threads, progress_msg
)
case _:
msg = f"Unsupported scheduling type {async_type}"
@@ -62,6 +63,7 @@ async def derive_from_rows_asyncio_threads(
transform: Callable[[pd.Series], Awaitable[ItemType]],
callbacks: WorkflowCallbacks,
num_threads: int | None = 4,
progress_msg: str = "",
) -> list[ItemType | None]:
"""
Derive from rows asynchronously.
@@ -81,7 +83,9 @@ async def derive_from_rows_asyncio_threads(

return await asyncio.gather(*[execute_task(task) for task in tasks])

return await _derive_from_rows_base(input, transform, callbacks, gather)
return await _derive_from_rows_base(
input, transform, callbacks, gather, progress_msg
)


"""A module containing the derive_from_rows_async method."""
@@ -92,6 +96,7 @@ async def derive_from_rows_asyncio(
transform: Callable[[pd.Series], Awaitable[ItemType]],
callbacks: WorkflowCallbacks,
num_threads: int = 4,
progress_msg: str = "",
) -> list[ItemType | None]:
"""
Derive from rows asynchronously.
@@ -112,7 +117,9 @@ async def derive_from_rows_asyncio(
]
return await asyncio.gather(*tasks)

return await _derive_from_rows_base(input, transform, callbacks, gather)
return await _derive_from_rows_base(
input, transform, callbacks, gather, progress_msg
)


ItemType = TypeVar("ItemType")
@@ -126,13 +133,16 @@ async def _derive_from_rows_base(
transform: Callable[[pd.Series], Awaitable[ItemType]],
callbacks: WorkflowCallbacks,
gather: GatherFn[ItemType],
progress_msg: str = "",
) -> list[ItemType | None]:
"""
Derive from rows asynchronously.

This is useful for IO bound operations.
"""
tick = progress_ticker(callbacks.progress, num_total=len(input))
tick = progress_ticker(
callbacks.progress, num_total=len(input), description=progress_msg
)
errors: list[tuple[BaseException, str]] = []

async def execute(row: tuple[Any, pd.Series]) -> ItemType | None:
@@ -153,7 +163,9 @@ async def _derive_from_rows_base(
tick.done()

for error, stack in errors:
callbacks.error("parallel transformation error", error, stack)
logger.error(
"parallel transformation error", exc_info=error, extra={"stack": stack}
)

if len(errors) > 0:
raise ParallelizationError(len(errors), errors[0][1])

@@ -11,7 +11,7 @@ import graphrag.config.defaults as defs

DEFAULT_ENCODING_NAME = defs.ENCODING_MODEL

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


def num_tokens_from_string(
@@ -23,7 +23,7 @@ def num_tokens_from_string(
encoding = tiktoken.encoding_for_model(model)
except KeyError:
msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
log.warning(msg)
logger.warning(msg)
encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
else:
encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)

@@ -4,15 +4,17 @@
"""A module containing validate_config_names definition."""

import asyncio
import logging
import sys

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager
from graphrag.logger.print_progress import ProgressLogger

logger = logging.getLogger(__name__)


def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) -> None:
def validate_config_names(parameters: GraphRagConfig) -> None:
"""Validate config file for LLM deployment name typos."""
# Validate Chat LLM configs
# TODO: Replace default_chat_model with a way to select the model
@@ -28,7 +30,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->

try:
asyncio.run(llm.achat("This is an LLM connectivity test. Say Hello World"))
logger.success("LLM Config Params Validated")
logger.info("LLM Config Params Validated")
except Exception as e:  # noqa: BLE001
logger.error(f"LLM configuration error detected. Exiting...\n{e}")  # noqa
sys.exit(1)
@@ -48,7 +50,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->

try:
asyncio.run(embed_llm.aembed_batch(["This is an LLM Embedding Test String"]))
logger.success("Embedding LLM Config Params Validated")
logger.info("Embedding LLM Config Params Validated")
except Exception as e:  # noqa: BLE001
logger.error(f"Embedding LLM configuration error detected. Exiting...\n{e}")  # noqa
sys.exit(1)

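The `logger.success(...)` calls above become `logger.info(...)` because the stdlib `logging` module has no SUCCESS level (that method comes from loguru-style loggers). If a distinct level were still wanted, one stdlib-only alternative is registering a custom level; a sketch, with the level number 25 an arbitrary choice of this example:

```python
import logging

SUCCESS = 25  # between INFO (20) and WARNING (30); arbitrary choice
logging.addLevelName(SUCCESS, "SUCCESS")

demo = logging.getLogger("validate_demo")  # hypothetical name
demo.setLevel(SUCCESS)

records: list[tuple[str, str]] = []

class Capture(logging.Handler):
    """Collect records in memory so the custom level name is observable."""
    def emit(self, record: logging.LogRecord) -> None:
        records.append((record.levelname, record.getMessage()))

demo.addHandler(Capture())
demo.log(SUCCESS, "LLM Config Params Validated")
```

Using plain `logger.info`, as the diff does, avoids carrying this custom level through every logging configuration, which is the simpler trade-off.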
@@ -4,6 +4,7 @@
"""A module containing run_workflow method definition."""

import json
import logging
from typing import Any, cast

import pandas as pd
@@ -19,12 +20,15 @@ from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.progress import Progress
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

logger = logging.getLogger(__name__)


async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform base text_units."""
logger.info("Workflow started: create_base_text_units")
documents = await load_table_from_storage("documents", context.output_storage)

chunks = config.chunks
@@ -43,6 +47,7 @@ async def run_workflow(

await write_table_to_storage(output, "text_units", context.output_storage)

logger.info("Workflow completed: create_base_text_units")
return WorkflowFunctionOutput(result=output)


@@ -81,7 +86,7 @@ def create_base_text_units(
)
aggregated.rename(columns={"text_with_ids": "texts"}, inplace=True)

def chunker(row: dict[str, Any]) -> Any:
def chunker(row: pd.Series) -> Any:
line_delimiter = ".\n"
metadata_str = ""
metadata_tokens = 0
@@ -125,7 +130,19 @@ def create_base_text_units(
row["chunks"] = chunked
return row

aggregated = aggregated.apply(lambda row: chunker(row), axis=1)
# Track progress of row-wise apply operation
total_rows = len(aggregated)
logger.info("Starting chunking process for %d documents", total_rows)

def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
"""Add logging to chunker execution."""
result = chunker(row)
logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
return result

aggregated = aggregated.apply(
lambda row: chunker_with_logging(row, row.name), axis=1
)

aggregated = cast("pd.DataFrame", aggregated[[*group_by_columns, "chunks"]])
aggregated = aggregated.explode("chunks")

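The progress wrapper above passes `row.name` as the counter, which is the DataFrame index label rather than a position, so the `%d/%d` count is only sequential when the index is the default `RangeIndex`. A sketch of the same idea with a closure counter that avoids that assumption (logger name, sample data, and the trivial `split()`-based chunker are all invented for illustration):

```python
import logging
import pandas as pd

logger = logging.getLogger("chunker_demo")  # hypothetical name
df = pd.DataFrame({"texts": ["alpha beta", "gamma", "delta epsilon zeta"]})

def chunker(row: pd.Series) -> pd.Series:
    # Stand-in for the real token-based chunking in the workflow.
    row["chunks"] = row["texts"].split()
    return row

total_rows = len(df)
count = {"done": 0}  # closure state instead of relying on index labels

def chunker_with_logging(row: pd.Series) -> pd.Series:
    result = chunker(row)
    count["done"] += 1
    logger.info("chunker progress: %d/%d", count["done"], total_rows)
    return result

chunked = df.apply(chunker_with_logging, axis=1)
```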
@ -3,6 +3,7 @@
|
||||
|
||||
"""A module containing run_workflow method definition."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
@ -18,12 +19,15 @@ from graphrag.index.typing.context import PipelineRunContext
|
||||
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||||
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_workflow(
|
||||
config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
) -> WorkflowFunctionOutput:
|
||||
"""All the steps to transform final communities."""
|
||||
logger.info("Workflow started: create_communities")
|
||||
entities = await load_table_from_storage("entities", context.output_storage)
|
||||
relationships = await load_table_from_storage(
|
||||
"relationships", context.output_storage
|
||||
@ -43,6 +47,7 @@ async def run_workflow(
|
||||
|
||||
await write_table_to_storage(output, "communities", context.output_storage)
|
||||
|
||||
logger.info("Workflow completed: create_communities")
|
||||
return WorkflowFunctionOutput(result=output)
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@

"""A module containing run_workflow method definition."""

import logging

import pandas as pd

import graphrag.data_model.schemas as schemas
@@ -32,12 +34,15 @@ from graphrag.utils.storage import (
    write_table_to_storage,
)

logger = logging.getLogger(__name__)


async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform community reports."""
    logger.info("Workflow started: create_community_reports")
    edges = await load_table_from_storage("relationships", context.output_storage)
    entities = await load_table_from_storage("entities", context.output_storage)
    communities = await load_table_from_storage("communities", context.output_storage)
@@ -70,6 +75,7 @@ async def run_workflow(

    await write_table_to_storage(output, "community_reports", context.output_storage)

    logger.info("Workflow completed: create_community_reports")
    return WorkflowFunctionOutput(result=output)

@@ -29,7 +29,7 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


async def run_workflow(
@@ -37,6 +37,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform community reports."""
    logger.info("Workflow started: create_community_reports_text")
    entities = await load_table_from_storage("entities", context.output_storage)
    communities = await load_table_from_storage("communities", context.output_storage)

@@ -64,6 +65,7 @@ async def run_workflow(

    await write_table_to_storage(output, "community_reports", context.output_storage)

    logger.info("Workflow completed: create_community_reports_text")
    return WorkflowFunctionOutput(result=output)

@@ -3,6 +3,8 @@

"""A module containing run_workflow method definition."""

import logging

import pandas as pd

from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -11,12 +13,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

logger = logging.getLogger(__name__)


async def run_workflow(
    _config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform final documents."""
    logger.info("Workflow started: create_final_documents")
    documents = await load_table_from_storage("documents", context.output_storage)
    text_units = await load_table_from_storage("text_units", context.output_storage)

@@ -24,6 +29,7 @@ async def run_workflow(

    await write_table_to_storage(output, "documents", context.output_storage)

    logger.info("Workflow completed: create_final_documents")
    return WorkflowFunctionOutput(result=output)

@@ -3,6 +3,8 @@

"""A module containing run_workflow method definition."""

import logging

import pandas as pd

from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -15,12 +17,15 @@ from graphrag.utils.storage import (
    write_table_to_storage,
)

logger = logging.getLogger(__name__)


async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform the text units."""
    logger.info("Workflow started: create_final_text_units")
    text_units = await load_table_from_storage("text_units", context.output_storage)
    final_entities = await load_table_from_storage("entities", context.output_storage)
    final_relationships = await load_table_from_storage(
@@ -43,6 +48,7 @@ async def run_workflow(

    await write_table_to_storage(output, "text_units", context.output_storage)

    logger.info("Workflow completed: create_final_text_units")
    return WorkflowFunctionOutput(result=output)

@@ -3,6 +3,7 @@

"""A module containing run_workflow method definition."""

import logging
from typing import Any
from uuid import uuid4

@@ -20,12 +21,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

logger = logging.getLogger(__name__)


async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to extract and format covariates."""
    logger.info("Workflow started: extract_covariates")
    output = None
    if config.extract_claims.enabled:
        text_units = await load_table_from_storage("text_units", context.output_storage)
@@ -53,6 +57,7 @@ async def run_workflow(

        await write_table_to_storage(output, "covariates", context.output_storage)

    logger.info("Workflow completed: extract_covariates")
    return WorkflowFunctionOutput(result=output)

@@ -3,6 +3,7 @@

"""A module containing run_workflow method definition."""

import logging
from typing import Any

import pandas as pd
@@ -21,12 +22,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

logger = logging.getLogger(__name__)


async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to create the base entity graph."""
    logger.info("Workflow started: extract_graph")
    text_units = await load_table_from_storage("text_units", context.output_storage)

    extract_graph_llm_settings = config.get_language_model_config(
@@ -66,6 +70,7 @@ async def run_workflow(
        raw_relationships, "raw_relationships", context.output_storage
    )

    logger.info("Workflow completed: extract_graph")
    return WorkflowFunctionOutput(
        result={
            "entities": entities,
@@ -101,14 +106,14 @@ async def extract_graph(

    if not _validate_data(extracted_entities):
        error_msg = "Entity Extraction failed. No entities detected during extraction."
        callbacks.error(error_msg)
        logger.error(error_msg)
        raise ValueError(error_msg)

    if not _validate_data(extracted_relationships):
        error_msg = (
            "Entity Extraction failed. No relationships detected during extraction."
        )
        callbacks.error(error_msg)
        logger.error(error_msg)
        raise ValueError(error_msg)

    # copy these as is before any summarization

@@ -3,6 +3,8 @@

"""A module containing run_workflow method definition."""

import logging

import pandas as pd

from graphrag.cache.pipeline_cache import PipelineCache
@@ -16,12 +18,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

logger = logging.getLogger(__name__)


async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to create the base entity graph."""
    logger.info("Workflow started: extract_graph_nlp")
    text_units = await load_table_from_storage("text_units", context.output_storage)

    entities, relationships = await extract_graph_nlp(
@@ -33,6 +38,8 @@ async def run_workflow(
    await write_table_to_storage(entities, "entities", context.output_storage)
    await write_table_to_storage(relationships, "relationships", context.output_storage)

    logger.info("Workflow completed: extract_graph_nlp")

    return WorkflowFunctionOutput(
        result={
            "entities": entities,

@@ -3,6 +3,7 @@

"""Encapsulates pipeline construction and selection."""

import logging
from typing import ClassVar

from graphrag.config.enums import IndexingMethod
@@ -10,6 +11,8 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.typing.pipeline import Pipeline
from graphrag.index.typing.workflow import WorkflowFunction

logger = logging.getLogger(__name__)


class PipelineFactory:
    """A factory class for workflow pipelines."""
@@ -41,6 +44,7 @@ class PipelineFactory:
    ) -> Pipeline:
        """Create a pipeline generator."""
        workflows = config.workflows or cls.pipelines.get(method, [])
        logger.info("Creating pipeline with workflows: %s", workflows)
        return Pipeline([(name, cls.workflows[name]) for name in workflows])

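The `PipelineFactory` hunk above logs the workflow list while assembling a pipeline from a class-level registry. A toy sketch of that registry shape (all names here are illustrative, not the real GraphRAG API):

```python
import logging
from typing import Callable, ClassVar

logger = logging.getLogger(__name__)


class MiniPipelineFactory:
    """Toy class-level registry mirroring the factory pattern in the diff."""

    workflows: ClassVar[dict[str, Callable[[], str]]] = {}

    @classmethod
    def register(cls, name: str, fn: Callable[[], str]) -> None:
        """Register a workflow function under a name."""
        cls.workflows[name] = fn

    @classmethod
    def create_pipeline(cls, names: list[str]) -> list[tuple[str, Callable[[], str]]]:
        """Resolve names to (name, function) pairs, logging the selection."""
        logger.info("Creating pipeline with workflows: %s", names)
        return [(name, cls.workflows[name]) for name in names]


MiniPipelineFactory.register("create_communities", lambda: "ok")
pipeline = MiniPipelineFactory.create_pipeline(["create_communities"])
```

Logging the resolved workflow list at pipeline-creation time makes custom `config.workflows` overrides visible in the run log.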
@@ -3,9 +3,10 @@

"""A module containing run_workflow method definition."""

import logging

import pandas as pd

from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.create_graph import create_graph
@@ -16,12 +17,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

logger = logging.getLogger(__name__)


async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to create the base entity graph."""
    logger.info("Workflow started: finalize_graph")
    entities = await load_table_from_storage("entities", context.output_storage)
    relationships = await load_table_from_storage(
        "relationships", context.output_storage
@@ -30,7 +34,6 @@ async def run_workflow(
    final_entities, final_relationships = finalize_graph(
        entities,
        relationships,
        callbacks=context.callbacks,
        embed_config=config.embed_graph,
        layout_enabled=config.umap.enabled,
    )
@@ -50,6 +53,7 @@ async def run_workflow(
        storage=context.output_storage,
    )

    logger.info("Workflow completed: finalize_graph")
    return WorkflowFunctionOutput(
        result={
            "entities": entities,
@@ -61,13 +65,12 @@ async def run_workflow(
def finalize_graph(
    entities: pd.DataFrame,
    relationships: pd.DataFrame,
    callbacks: WorkflowCallbacks,
    embed_config: EmbedGraphConfig | None = None,
    layout_enabled: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """All the steps to finalize the entity and relationship formats."""
    final_entities = finalize_entities(
        entities, relationships, callbacks, embed_config, layout_enabled
        entities, relationships, embed_config, layout_enabled
    )
    final_relationships = finalize_relationships(relationships)
    return (final_entities, final_relationships)

@@ -30,7 +30,7 @@ from graphrag.utils.storage import (
    write_table_to_storage,
)

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


async def run_workflow(
@@ -38,6 +38,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform community reports."""
    logger.info("Workflow started: generate_text_embeddings")
    documents = None
    relationships = None
    text_units = None
@@ -81,6 +82,7 @@ async def run_workflow(
        context.output_storage,
    )

    logger.info("Workflow completed: generate_text_embeddings")
    return WorkflowFunctionOutput(result=output)

@@ -145,12 +147,12 @@ async def generate_text_embeddings(
        },
    }

    log.info("Creating embeddings")
    logger.info("Creating embeddings")
    outputs = {}
    for field in embedded_fields:
        if embedding_param_map[field]["data"] is None:
            msg = f"Embedding {field} is specified but data table is not in storage. This may or may not be intentional - if you expect it to be here, please check for errors earlier in the logs."
            log.warning(msg)
            logger.warning(msg)
        else:
            outputs[field] = await _run_embeddings(
                name=field,

@@ -12,11 +12,10 @@ from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.factory import create_input
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.storage import write_table_to_storage

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


async def run_workflow(
@@ -27,10 +26,9 @@ async def run_workflow(
    output = await load_input_documents(
        config.input,
        context.input_storage,
        context.progress_logger,
    )

    log.info("Final # of rows loaded: %s", len(output))
    logger.info("Final # of rows loaded: %s", len(output))
    context.stats.num_documents = len(output)

    await write_table_to_storage(output, "documents", context.output_storage)
@@ -39,7 +37,7 @@ async def run_workflow(


async def load_input_documents(
    config: InputConfig, storage: PipelineStorage, progress_logger: ProgressLogger
    config: InputConfig, storage: PipelineStorage
) -> pd.DataFrame:
    """Load and parse input documents into a standard format."""
    return await create_input(config, storage, progress_logger)
    return await create_input(config, storage)

@@ -13,11 +13,10 @@ from graphrag.index.input.factory import create_input
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.update.incremental_index import get_delta_docs
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.storage import write_table_to_storage

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


async def run_workflow(
@@ -29,15 +28,13 @@ async def run_workflow(
        config.input,
        context.input_storage,
        context.previous_storage,
        context.progress_logger,
    )

    log.info("Final # of update rows loaded: %s", len(output))
    logger.info("Final # of update rows loaded: %s", len(output))
    context.stats.update_documents = len(output)

    if len(output) == 0:
        log.warning("No new update documents found.")
        context.progress_logger.warning("No new update documents found.")
        logger.warning("No new update documents found.")
        return WorkflowFunctionOutput(result=None, stop=True)

    await write_table_to_storage(output, "documents", context.output_storage)
@@ -49,10 +46,9 @@ async def load_update_documents(
    config: InputConfig,
    input_storage: PipelineStorage,
    previous_storage: PipelineStorage,
    progress_logger: ProgressLogger,
) -> pd.DataFrame:
    """Load and parse update-only input documents into a standard format."""
    input_documents = await create_input(config, input_storage, progress_logger)
    input_documents = await create_input(config, input_storage)
    # previous storage is the output of the previous run
    # we'll use this to diff the input from the prior
    delta_documents = await get_delta_docs(input_documents, previous_storage)

@@ -3,6 +3,8 @@

"""A module containing run_workflow method definition."""

import logging

import pandas as pd

from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -14,12 +16,15 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

logger = logging.getLogger(__name__)


async def run_workflow(
    config: GraphRagConfig,
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to create the base entity graph."""
    logger.info("Workflow started: prune_graph")
    entities = await load_table_from_storage("entities", context.output_storage)
    relationships = await load_table_from_storage(
        "relationships", context.output_storage
@@ -36,6 +41,7 @@ async def run_workflow(
        pruned_relationships, "relationships", context.output_storage
    )

    logger.info("Workflow completed: prune_graph")
    return WorkflowFunctionOutput(
        result={
            "entities": pruned_entities,

@@ -17,7 +17,7 @@ async def run_workflow(  # noqa: RUF029
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """Clean the state after the update."""
    logger.info("Cleaning State")
    logger.info("Workflow started: update_clean_state")
    keys_to_delete = [
        key_name
        for key_name in context.state
@@ -27,4 +27,5 @@ async def run_workflow(  # noqa: RUF029
    for key_name in keys_to_delete:
        del context.state[key_name]

    logger.info("Workflow completed: update_clean_state")
    return WorkflowFunctionOutput(result=None)

@@ -21,7 +21,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """Update the communities from an incremental index run."""
    logger.info("Updating Communities")
    logger.info("Workflow started: update_communities")
    output_storage, previous_storage, delta_storage = get_update_storages(
        config, context.state["update_timestamp"]
    )
@@ -32,6 +32,7 @@ async def run_workflow(

    context.state["incremental_update_community_id_mapping"] = community_id_mapping

    logger.info("Workflow completed: update_communities")
    return WorkflowFunctionOutput(result=None)

@@ -23,7 +23,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """Update the community reports from an incremental index run."""
    logger.info("Updating Community Reports")
    logger.info("Workflow started: update_community_reports")
    output_storage, previous_storage, delta_storage = get_update_storages(
        config, context.state["update_timestamp"]
    )
@@ -38,6 +38,7 @@ async def run_workflow(
        merged_community_reports
    )

    logger.info("Workflow completed: update_community_reports")
    return WorkflowFunctionOutput(result=None)

@@ -27,6 +27,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """Update the covariates from an incremental index run."""
    logger.info("Workflow started: update_covariates")
    output_storage, previous_storage, delta_storage = get_update_storages(
        config, context.state["update_timestamp"]
    )
@@ -37,6 +38,7 @@ async def run_workflow(
        logger.info("Updating Covariates")
        await _update_covariates(previous_storage, delta_storage, output_storage)

    logger.info("Workflow completed: update_covariates")
    return WorkflowFunctionOutput(result=None)

@@ -27,7 +27,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """Update the entities and relationships from an incremental index run."""
    logger.info("Updating Entities and Relationships")
    logger.info("Workflow started: update_entities_relationships")
    output_storage, previous_storage, delta_storage = get_update_storages(
        config, context.state["update_timestamp"]
    )
@@ -49,6 +49,7 @@ async def run_workflow(
    context.state["incremental_update_merged_relationships"] = merged_relationships_df
    context.state["incremental_update_entity_id_mapping"] = entity_id_mapping

    logger.info("Workflow completed: update_entities_relationships")
    return WorkflowFunctionOutput(result=None)

@@ -19,7 +19,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """Update the documents from an incremental index run."""
    logger.info("Updating Documents")
    logger.info("Workflow started: update_final_documents")
    output_storage, previous_storage, delta_storage = get_update_storages(
        config, context.state["update_timestamp"]
    )
@@ -30,4 +30,5 @@ async def run_workflow(

    context.state["incremental_update_final_documents"] = final_documents

    logger.info("Workflow completed: update_final_documents")
    return WorkflowFunctionOutput(result=None)

@@ -21,7 +21,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """Update the text embeddings from an incremental index run."""
    logger.info("Updating Text Embeddings")
    logger.info("Workflow started: update_text_embeddings")
    output_storage, _, _ = get_update_storages(
        config, context.state["update_timestamp"]
    )
@@ -55,4 +55,5 @@ async def run_workflow(
        output_storage,
    )

    logger.info("Workflow completed: update_text_embeddings")
    return WorkflowFunctionOutput(result=None)

@@ -23,7 +23,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """Update the text units from an incremental index run."""
    logger.info("Updating Text Units")
    logger.info("Workflow started: update_text_units")
    output_storage, previous_storage, delta_storage = get_update_storages(
        config, context.state["update_timestamp"]
    )
@@ -35,6 +35,7 @@ async def run_workflow(

    context.state["incremental_update_merged_text_units"] = merged_text_units

    logger.info("Workflow completed: update_text_units")
    return WorkflowFunctionOutput(result=None)

@@ -11,6 +11,8 @@ from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar

from typing_extensions import Self

from graphrag.language_model.factory import ModelFactory

if TYPE_CHECKING:
@@ -22,11 +24,11 @@ class ModelManager:

    _instance: ClassVar[ModelManager | None] = None

    def __new__(cls) -> ModelManager:  # noqa: PYI034: False positive
    def __new__(cls) -> Self:
        """Create a new instance of LLMManager if it does not exist."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
        return cls._instance  # type: ignore[return-value]

    def __init__(self) -> None:
        # Avoid reinitialization in the singleton.

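The `ModelManager` hunk above swaps a concrete `-> ModelManager` annotation for `Self` (imported from `typing_extensions`, since the project still supports Python 3.10). A minimal sketch of the singleton pattern that annotation decorates, using a plain class name here to stay stdlib-only:

```python
from __future__ import annotations

from typing import ClassVar


class SingletonManager:
    """Toy singleton mirroring ModelManager; the real code annotates __new__ -> Self."""

    _instance: ClassVar[SingletonManager | None] = None

    def __new__(cls) -> SingletonManager:
        # Lazily create the single instance; later calls return the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance


a = SingletonManager()
b = SingletonManager()
```

`Self` is the more precise annotation because subclasses of the manager would then be typed as returning instances of themselves, not of the base class.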
@@ -6,6 +6,7 @@
from __future__ import annotations

import asyncio
import logging
import threading
from typing import TYPE_CHECKING, Any, TypeVar

@@ -26,6 +27,8 @@ if TYPE_CHECKING:
    )
    from graphrag.index.typing.error_handler import ErrorHandlerFn

logger = logging.getLogger(__name__)


def _create_cache(cache: PipelineCache | None, name: str) -> FNLLMCacheProvider | None:
    """Create an FNLLM cache from a pipeline cache."""
@@ -34,7 +37,7 @@ def _create_cache(cache: PipelineCache | None, name: str) -> FNLLMCacheProvider
    return FNLLMCacheProvider(cache).child(name)


def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:
def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:  # noqa: ARG001
    """Create an error handler from a WorkflowCallbacks."""

    def on_error(
@@ -42,7 +45,11 @@ def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:
        stack: str | None = None,
        details: dict | None = None,
    ) -> None:
        callbacks.error("Error Invoking LLM", error, stack, details)
        logger.error(
            "Error Invoking LLM",
            exc_info=error,
            extra={"stack": stack, "details": details},
        )

    return on_error

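The replacement above routes LLM errors through `logger.error` with `exc_info` and an `extra` dict; the `extra` keys become attributes on the emitted `LogRecord`, where handlers (such as the blob handler later in this diff) can read them. A small self-contained check of that mechanism (logger name and payload are illustrative):

```python
import logging

captured: list[logging.LogRecord] = []


class CaptureHandler(logging.Handler):
    """Collects records so the attributes added via extra= can be inspected."""

    def emit(self, record: logging.LogRecord) -> None:
        captured.append(record)


log = logging.getLogger("sketch.fnllm_errors")
log.addHandler(CaptureHandler())
log.setLevel(logging.ERROR)

try:
    raise ValueError("boom")
except ValueError as err:
    # extra= attaches "stack" and "details" as custom LogRecord attributes;
    # key names must not collide with built-in record attributes.
    log.error(
        "Error Invoking LLM",
        exc_info=err,
        extra={"stack": None, "details": {"attempt": 1}},
    )

record = captured[0]
```

Note that `logging` rejects `extra` keys that shadow built-in record attributes (e.g. `message`, `exc_info`), which is why the diff uses the non-reserved names `stack` and `details`.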
@@ -1,69 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Base classes for logging and progress reporting."""

from abc import ABC, abstractmethod
from typing import Any

from graphrag.logger.progress import Progress


class StatusLogger(ABC):
    """Provides a way to log status updates from the pipeline."""

    @abstractmethod
    def error(self, message: str, details: dict[str, Any] | None = None):
        """Log an error."""

    @abstractmethod
    def warning(self, message: str, details: dict[str, Any] | None = None):
        """Log a warning."""

    @abstractmethod
    def log(self, message: str, details: dict[str, Any] | None = None):
        """Report a log."""


class ProgressLogger(ABC):
    """
    Abstract base class for progress loggers.

    This is used to report workflow processing progress via mechanisms like progress-bars.
    """

    @abstractmethod
    def __call__(self, update: Progress):
        """Update progress."""

    @abstractmethod
    def dispose(self):
        """Dispose of the progress logger."""

    @abstractmethod
    def child(self, prefix: str, transient=True) -> "ProgressLogger":
        """Create a child progress bar."""

    @abstractmethod
    def force_refresh(self) -> None:
        """Force a refresh."""

    @abstractmethod
    def stop(self) -> None:
        """Stop the progress logger."""

    @abstractmethod
    def error(self, message: str) -> None:
        """Log an error."""

    @abstractmethod
    def warning(self, message: str) -> None:
        """Log a warning."""

    @abstractmethod
    def info(self, message: str) -> None:
        """Log information."""

    @abstractmethod
    def success(self, message: str) -> None:
        """Log success."""
@ -4,6 +4,7 @@
|
||||
"""A logger that emits updates from the indexing engine to a blob in Azure Storage."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@ -11,11 +12,9 @@ from typing import Any
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
|
||||
|
||||
class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
"""A logger that writes to a blob storage account."""
|
||||
class BlobWorkflowLogger(logging.Handler):
|
||||
"""A logging handler that writes to a blob storage account."""
|
||||
|
||||
_blob_service_client: BlobServiceClient
|
||||
_container_name: str
|
||||
@ -28,16 +27,21 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
blob_name: str = "",
|
||||
base_dir: str | None = None,
|
||||
storage_account_blob_url: str | None = None,
|
||||
level: int = logging.NOTSET,
|
||||
):
|
||||
"""Create a new instance of the BlobStorageReporter class."""
|
||||
"""Create a new instance of the BlobWorkflowLogger class."""
|
||||
super().__init__(level)
|
||||
|
||||
if container_name is None:
|
||||
msg = "No container name provided for blob storage."
|
||||
raise ValueError(msg)
|
||||
if connection_string is None and storage_account_blob_url is None:
|
||||
msg = "No storage account blob url provided for blob storage."
|
||||
raise ValueError(msg)
|
||||
|
||||
self._connection_string = connection_string
|
||||
self._storage_account_blob_url = storage_account_blob_url
|
||||
|
||||
if self._connection_string:
|
||||
self._blob_service_client = BlobServiceClient.from_connection_string(
|
||||
self._connection_string
|
||||
@ -65,7 +69,37 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
|
||||
self._num_blocks = 0 # refresh block counter
|
||||
|
||||
def emit(self, record) -> None:
|
||||
"""Emit a log record to blob storage."""
|
||||
try:
|
||||
# Create JSON structure based on record
|
||||
log_data = {
|
||||
"type": self._get_log_type(record.levelno),
|
||||
"data": record.getMessage(),
|
||||
}
|
||||
|
||||
# Add additional fields if they exist
|
||||
if hasattr(record, "details") and record.details: # type: ignore[reportAttributeAccessIssue]
|
||||
log_data["details"] = record.details # type: ignore[reportAttributeAccessIssue]
|
||||
if record.exc_info and record.exc_info[1]:
|
||||
log_data["cause"] = str(record.exc_info[1])
|
||||
if hasattr(record, "stack") and record.stack: # type: ignore[reportAttributeAccessIssue]
|
||||
log_data["stack"] = record.stack # type: ignore[reportAttributeAccessIssue]
|
||||
|
||||
self._write_log(log_data)
|
||||
except (OSError, ValueError):
|
||||
self.handleError(record)
|
||||
|
||||
def _get_log_type(self, level: int) -> str:
|
||||
"""Get log type string based on log level."""
|
||||
if level >= logging.ERROR:
|
||||
return "error"
|
||||
if level >= logging.WARNING:
|
||||
return "warning"
|
||||
return "log"
|
||||
|
||||
def _write_log(self, log: dict[str, Any]):
|
||||
"""Write log data to blob storage."""
|
||||
# create a new file when block count hits close 25k
|
||||
if (
|
||||
self._num_blocks >= self._max_block_count
|
||||
@ -83,27 +117,3 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
|
||||
# update the blob's block count
|
||||
self._num_blocks += 1
|
||||
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
stack: str | None = None,
|
||||
details: dict | None = None,
|
||||
):
|
||||
"""Report an error."""
|
||||
self._write_log({
|
||||
"type": "error",
|
||||
"data": message,
|
||||
"cause": str(cause),
|
||||
"stack": stack,
|
||||
"details": details,
|
||||
})
|
||||
|
||||
def warning(self, message: str, details: dict | None = None):
|
||||
"""Report a warning."""
|
||||
self._write_log({"type": "warning", "data": message, "details": details})
|
||||
|
||||
def log(self, message: str, details: dict | None = None):
|
||||
"""Report a generic log message."""
|
||||
self._write_log({"type": "log", "data": message, "details": details})
|
||||
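The diff above converts the blob callback class into a `logging.Handler` whose `emit()` maps record levels to `"error"`/`"warning"`/`"log"` type strings. A minimal, self-contained sketch of that pattern (not graphrag's actual `BlobWorkflowLogger` — it collects dicts in memory instead of writing to blob storage):

```python
import logging


class DictCollectingHandler(logging.Handler):
    """Collects each record as a structured dict, mirroring the emit() pattern above."""

    def __init__(self) -> None:
        super().__init__()
        self.entries: list[dict] = []

    @staticmethod
    def _get_log_type(level: int) -> str:
        # Same threshold logic as the diff: ERROR and above -> "error",
        # WARNING and above -> "warning", everything else -> "log".
        if level >= logging.ERROR:
            return "error"
        if level >= logging.WARNING:
            return "warning"
        return "log"

    def emit(self, record: logging.LogRecord) -> None:
        try:
            entry: dict = {
                "type": self._get_log_type(record.levelno),
                "data": record.getMessage(),
            }
            # attach the exception instance, if the caller passed exc_info
            if record.exc_info and record.exc_info[1]:
                entry["cause"] = str(record.exc_info[1])
            self.entries.append(entry)
        except (OSError, ValueError):
            self.handleError(record)


log = logging.getLogger("demo")
log.setLevel(logging.DEBUG)
handler = DictCollectingHandler()
log.addHandler(handler)
log.warning("disk almost full")
log.info("indexing started")
```

Swapping the in-memory list for an append-blob write gives the shape of the real handler.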
@@ -1,28 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Console Log."""
-
-from typing import Any
-
-from graphrag.logger.base import StatusLogger
-
-
-class ConsoleReporter(StatusLogger):
-    """A logger that writes to a console."""
-
-    def error(self, message: str, details: dict[str, Any] | None = None):
-        """Log an error."""
-        print(message, details)  # noqa T201
-
-    def warning(self, message: str, details: dict[str, Any] | None = None):
-        """Log a warning."""
-        _print_warning(message)
-
-    def log(self, message: str, details: dict[str, Any] | None = None):
-        """Log a log."""
-        print(message, details)  # noqa T201
-
-
-def _print_warning(skk):
-    print(f"\033[93m {skk}\033[00m")  # noqa T201
@@ -1,43 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Factory functions for creating loggers."""
-
-from typing import ClassVar
-
-from graphrag.logger.base import ProgressLogger
-from graphrag.logger.null_progress import NullProgressLogger
-from graphrag.logger.print_progress import PrintProgressLogger
-from graphrag.logger.rich_progress import RichProgressLogger
-from graphrag.logger.types import LoggerType
-
-
-class LoggerFactory:
-    """A factory class for loggers."""
-
-    logger_types: ClassVar[dict[str, type]] = {}
-
-    @classmethod
-    def register(cls, logger_type: str, logger: type):
-        """Register a custom logger implementation."""
-        cls.logger_types[logger_type] = logger
-
-    @classmethod
-    def create_logger(
-        cls, logger_type: LoggerType | str, kwargs: dict | None = None
-    ) -> ProgressLogger:
-        """Create a logger based on the provided type."""
-        if kwargs is None:
-            kwargs = {}
-        match logger_type:
-            case LoggerType.RICH:
-                return RichProgressLogger("GraphRAG Indexer ")
-            case LoggerType.PRINT:
-                return PrintProgressLogger("GraphRAG Indexer ")
-            case LoggerType.NONE:
-                return NullProgressLogger()
-            case _:
-                if logger_type in cls.logger_types:
-                    return cls.logger_types[logger_type](**kwargs)
-                # default to null logger if no other logger is found
-                return NullProgressLogger()
@@ -1,38 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Null Progress Reporter."""
-
-from graphrag.logger.base import Progress, ProgressLogger
-
-
-class NullProgressLogger(ProgressLogger):
-    """A progress logger that does nothing."""
-
-    def __call__(self, update: Progress) -> None:
-        """Update progress."""
-
-    def dispose(self) -> None:
-        """Dispose of the progress logger."""
-
-    def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
-        """Create a child progress bar."""
-        return self
-
-    def force_refresh(self) -> None:
-        """Force a refresh."""
-
-    def stop(self) -> None:
-        """Stop the progress logger."""
-
-    def error(self, message: str) -> None:
-        """Log an error."""
-
-    def warning(self, message: str) -> None:
-        """Log a warning."""
-
-    def info(self, message: str) -> None:
-        """Log information."""
-
-    def success(self, message: str) -> None:
-        """Log success."""
@@ -1,50 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Print Progress Logger."""
-
-from graphrag.logger.base import Progress, ProgressLogger
-
-
-class PrintProgressLogger(ProgressLogger):
-    """A progress logger that prints progress to stdout."""
-
-    prefix: str
-
-    def __init__(self, prefix: str):
-        """Create a new progress logger."""
-        self.prefix = prefix
-        print(f"\n{self.prefix}", end="")  # noqa T201
-
-    def __call__(self, update: Progress) -> None:
-        """Update progress."""
-        print(".", end="")  # noqa T201
-
-    def dispose(self) -> None:
-        """Dispose of the progress logger."""
-
-    def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
-        """Create a child progress bar."""
-        return PrintProgressLogger(prefix)
-
-    def stop(self) -> None:
-        """Stop the progress logger."""
-
-    def force_refresh(self) -> None:
-        """Force a refresh."""
-
-    def error(self, message: str) -> None:
-        """Log an error."""
-        print(f"\n{self.prefix}ERROR: {message}")  # noqa T201
-
-    def warning(self, message: str) -> None:
-        """Log a warning."""
-        print(f"\n{self.prefix}WARNING: {message}")  # noqa T201
-
-    def info(self, message: str) -> None:
-        """Log information."""
-        print(f"\n{self.prefix}INFO: {message}")  # noqa T201
-
-    def success(self, message: str) -> None:
-        """Log success."""
-        print(f"\n{self.prefix}SUCCESS: {message}")  # noqa T201
@@ -1,13 +1,15 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""Progress reporting types."""
+"""Progress Logging Utilities."""

+import logging
 from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from typing import TypeVar

 T = TypeVar("T")
+logger = logging.getLogger(__name__)


 @dataclass
@@ -24,7 +26,7 @@ class Progress:
     """Total number of items"""

     completed_items: int | None = None
-    """Number of items completed""" ""
+    """Number of items completed"""


 ProgressHandler = Callable[[Progress], None]
@@ -35,11 +37,15 @@ class ProgressTicker:
     """A class that emits progress reports incrementally."""

     _callback: ProgressHandler | None
+    _description: str
     _num_total: int
     _num_complete: int

-    def __init__(self, callback: ProgressHandler | None, num_total: int):
+    def __init__(
+        self, callback: ProgressHandler | None, num_total: int, description: str = ""
+    ):
         self._callback = callback
+        self._description = description
         self._num_total = num_total
         self._num_complete = 0

@@ -47,35 +53,47 @@ class ProgressTicker:
         """Emit progress."""
         self._num_complete += num_ticks
         if self._callback is not None:
-            self._callback(
-                Progress(
-                    total_items=self._num_total, completed_items=self._num_complete
-                )
-            )
+            p = Progress(
+                total_items=self._num_total,
+                completed_items=self._num_complete,
+                description=self._description,
+            )
+            if p.description:
+                logger.info(
+                    "%s%s/%s", p.description, str(p.completed_items), str(p.total_items)
+                )
+            self._callback(p)

     def done(self) -> None:
         """Mark the progress as done."""
         if self._callback is not None:
             self._callback(
-                Progress(total_items=self._num_total, completed_items=self._num_total)
+                Progress(
+                    total_items=self._num_total,
+                    completed_items=self._num_total,
+                    description=self._description,
+                )
             )


-def progress_ticker(callback: ProgressHandler | None, num_total: int) -> ProgressTicker:
+def progress_ticker(
+    callback: ProgressHandler | None, num_total: int, description: str = ""
+) -> ProgressTicker:
     """Create a progress ticker."""
-    return ProgressTicker(callback, num_total)
+    return ProgressTicker(callback, num_total, description=description)


 def progress_iterable(
     iterable: Iterable[T],
     progress: ProgressHandler | None,
     num_total: int | None = None,
+    description: str = "",
 ) -> Iterable[T]:
     """Wrap an iterable with a progress handler. Every time an item is yielded, the progress handler will be called with the current progress."""
     if num_total is None:
         num_total = len(list(iterable))

-    tick = ProgressTicker(progress, num_total)
+    tick = ProgressTicker(progress, num_total, description=description)

     for item in iterable:
         tick(1)
@@ -1,165 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Rich-based progress logger for CLI use."""
-
-# Print iterations progress
-import asyncio
-
-from rich.console import Console, Group
-from rich.live import Live
-from rich.progress import Progress, TaskID, TimeElapsedColumn
-from rich.spinner import Spinner
-from rich.tree import Tree
-
-from graphrag.logger.base import ProgressLogger
-from graphrag.logger.progress import Progress as GRProgress
-
-
-# https://stackoverflow.com/a/34325723
-class RichProgressLogger(ProgressLogger):
-    """A rich-based progress logger for CLI use."""
-
-    _console: Console
-    _group: Group
-    _tree: Tree
-    _live: Live
-    _task: TaskID | None = None
-    _prefix: str
-    _transient: bool
-    _disposing: bool = False
-    _progressbar: Progress
-    _last_refresh: float = 0
-
-    def dispose(self) -> None:
-        """Dispose of the progress logger."""
-        self._disposing = True
-        self._live.stop()
-
-    @property
-    def console(self) -> Console:
-        """Get the console."""
-        return self._console
-
-    @property
-    def group(self) -> Group:
-        """Get the group."""
-        return self._group
-
-    @property
-    def tree(self) -> Tree:
-        """Get the tree."""
-        return self._tree
-
-    @property
-    def live(self) -> Live:
-        """Get the live."""
-        return self._live
-
-    def __init__(
-        self,
-        prefix: str,
-        parent: "RichProgressLogger | None" = None,
-        transient: bool = True,
-    ) -> None:
-        """Create a new rich-based progress logger."""
-        self._prefix = prefix
-
-        if parent is None:
-            console = Console()
-            group = Group(Spinner("dots", prefix), fit=True)
-            tree = Tree(group)
-            live = Live(
-                tree, console=console, refresh_per_second=1, vertical_overflow="crop"
-            )
-            live.start()
-
-            self._console = console
-            self._group = group
-            self._tree = tree
-            self._live = live
-            self._transient = False
-        else:
-            self._console = parent.console
-            self._group = parent.group
-            progress_columns = [*Progress.get_default_columns(), TimeElapsedColumn()]
-            self._progressbar = Progress(
-                *progress_columns, console=self._console, transient=transient
-            )
-
-            tree = Tree(prefix)
-            tree.add(self._progressbar)
-            tree.hide_root = True
-
-            if parent is not None:
-                parent_tree = parent.tree
-                parent_tree.hide_root = False
-                parent_tree.add(tree)
-
-            self._tree = tree
-            self._live = parent.live
-            self._transient = transient
-
-        self.refresh()
-
-    def refresh(self) -> None:
-        """Perform a debounced refresh."""
-        now = asyncio.get_event_loop().time()
-        duration = now - self._last_refresh
-        if duration > 0.1:
-            self._last_refresh = now
-            self.force_refresh()
-
-    def force_refresh(self) -> None:
-        """Force a refresh."""
-        self.live.refresh()
-
-    def stop(self) -> None:
-        """Stop the progress logger."""
-        self._live.stop()
-
-    def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
-        """Create a child progress bar."""
-        return RichProgressLogger(parent=self, prefix=prefix, transient=transient)
-
-    def error(self, message: str) -> None:
-        """Log an error."""
-        self._console.print(f"❌ [red]{message}[/red]")
-
-    def warning(self, message: str) -> None:
-        """Log a warning."""
-        self._console.print(f"⚠️ [yellow]{message}[/yellow]")
-
-    def success(self, message: str) -> None:
-        """Log success."""
-        self._console.print(f"🚀 [green]{message}[/green]")
-
-    def info(self, message: str) -> None:
-        """Log information."""
-        self._console.print(message)
-
-    def __call__(self, progress_update: GRProgress) -> None:
-        """Update progress."""
-        if self._disposing:
-            return
-        progressbar = self._progressbar
-
-        if self._task is None:
-            self._task = progressbar.add_task(self._prefix)
-
-        progress_description = ""
-        if progress_update.description is not None:
-            progress_description = f" - {progress_update.description}"
-
-        completed = progress_update.completed_items or progress_update.percent
-        total = progress_update.total_items or 1
-        progressbar.update(
-            self._task,
-            completed=completed,
-            total=total,
-            description=f"{self._prefix}{progress_description}",
-        )
-        if completed == total and self._transient:
-            progressbar.update(self._task, visible=False)
-
-        self.refresh()
graphrag/logger/standard_logging.py (new file, 153 lines)
@@ -0,0 +1,153 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Standard logging configuration for the graphrag package.
+
+This module provides a standardized way to configure Python's built-in
+logging system for use within the graphrag package.
+
+Usage:
+    # Configuration should be done once at the start of your application:
+    from graphrag.logger.standard_logging import init_loggers
+    init_loggers(log_file="/path/to/app.log")
+
+    # Then throughout your code:
+    import logging
+    logger = logging.getLogger(__name__)  # Use standard logging
+
+    # Use standard logging methods:
+    logger.debug("Debug message")
+    logger.info("Info message")
+    logger.warning("Warning message")
+    logger.error("Error message")
+    logger.critical("Critical error message")
+
+Notes
+-----
+The logging system is hierarchical. Loggers are organized in a tree structure,
+with the root logger named 'graphrag'. All loggers created with names starting
+with 'graphrag.' will be children of this root logger. This allows for consistent
+configuration of all graphrag-related logs throughout the application.
+
+All progress logging now uses this standard logging system for consistency.
+"""
+
+import logging
+import sys
+from pathlib import Path
+
+from graphrag.config.enums import ReportingType
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.config.models.reporting_config import ReportingConfig
+
+LOG_FORMAT = "%(asctime)s.%(msecs)04d - %(levelname)s - %(name)s - %(message)s"
+DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+
+def init_loggers(
+    config: GraphRagConfig | None = None,
+    root_dir: str | None = None,
+    verbose: bool = False,
+    log_file: str | Path | None = None,
+) -> None:
+    """Initialize logging handlers for graphrag based on configuration.
+
+    This function merges the functionality of configure_logging() and create_pipeline_logger()
+    to provide a unified way to set up logging for the graphrag package.
+
+    Parameters
+    ----------
+    config : GraphRagConfig | None, default=None
+        The GraphRAG configuration. If None, defaults to file-based reporting.
+    root_dir : str | None, default=None
+        The root directory for file-based logging.
+    verbose : bool, default=False
+        Whether to enable verbose (DEBUG) logging.
+    log_file : str | Path | None, default=None
+        Path to a specific log file. If provided, takes precedence over config.
+    """
+    # import BlobWorkflowLogger here to avoid circular imports
+    from graphrag.logger.blob_workflow_logger import BlobWorkflowLogger
+
+    # extract reporting config from GraphRagConfig if provided
+    reporting_config: ReportingConfig
+    if log_file:
+        # if log_file is provided directly, override config to use file-based logging
+        log_path = Path(log_file)
+        reporting_config = ReportingConfig(
+            type=ReportingType.file,
+            base_dir=str(log_path.parent),
+        )
+    elif config is not None:
+        # use the reporting configuration from GraphRagConfig
+        reporting_config = config.reporting
+    else:
+        # default to file-based logging if no config provided
+        reporting_config = ReportingConfig(base_dir="logs", type=ReportingType.file)
+
+    logger = logging.getLogger("graphrag")
+    log_level = logging.DEBUG if verbose else logging.INFO
+    logger.setLevel(log_level)
+
+    # clear any existing handlers to avoid duplicate logs
+    if logger.hasHandlers():
+        # Close file handlers properly before removing them
+        for handler in logger.handlers:
+            if isinstance(handler, logging.FileHandler):
+                handler.close()
+        logger.handlers.clear()
+
+    # create formatter with custom format
+    formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT)
+
+    init_console_logger(verbose)
+
+    # add more handlers based on configuration
+    handler: logging.Handler
+    match reporting_config.type:
+        case ReportingType.file:
+            if log_file:
+                # use the specific log file provided
+                log_file_path = Path(log_file)
+                log_file_path.parent.mkdir(parents=True, exist_ok=True)
+                handler = logging.FileHandler(str(log_file_path), mode="a")
+            else:
+                # use the config-based file path
+                log_dir = Path(root_dir or "") / (reporting_config.base_dir or "")
+                log_dir.mkdir(parents=True, exist_ok=True)
+                log_file_path = log_dir / "logs.txt"
+                handler = logging.FileHandler(str(log_file_path), mode="a")
+            handler.setFormatter(formatter)
+            logger.addHandler(handler)
+        case ReportingType.blob:
+            handler = BlobWorkflowLogger(
+                reporting_config.connection_string,
+                reporting_config.container_name,
+                base_dir=reporting_config.base_dir,
+                storage_account_blob_url=reporting_config.storage_account_blob_url,
+            )
+            logger.addHandler(handler)
+        case _:
+            logger.error("Unknown reporting type '%s'.", reporting_config.type)
+
+
+def init_console_logger(verbose: bool = False) -> None:
+    """Initialize a console logger if not already present.
+
+    This function sets up a logger that outputs log messages to STDOUT.
+
+    Parameters
+    ----------
+    verbose : bool, default=False
+        Whether to enable verbose (DEBUG) logging.
+    """
+    logger = logging.getLogger("graphrag")
+    logger.setLevel(logging.DEBUG if verbose else logging.INFO)
+    has_console_handler = any(
+        isinstance(h, logging.StreamHandler) for h in logger.handlers
+    )
+    if not has_console_handler:
+        console_handler = logging.StreamHandler(sys.stdout)
+        formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT)
+        console_handler.setFormatter(formatter)
+        logger.addHandler(console_handler)
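The core of `init_loggers` above is plain stdlib logging: configure a single package root logger once (level, formatter, file handler, clearing stale handlers), and let child loggers created with `logging.getLogger(__name__)` propagate to it. A self-contained sketch of that setup, using a hypothetical `demo_pkg` logger name and a temp directory instead of graphrag's config objects:

```python
import logging
import tempfile
from pathlib import Path

# same format strings as the new standard_logging module
LOG_FORMAT = "%(asctime)s.%(msecs)04d - %(levelname)s - %(name)s - %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


def init_demo_loggers(log_file: Path, verbose: bool = False) -> logging.Logger:
    """Configure the package root logger once; child loggers inherit its handlers."""
    logger = logging.getLogger("demo_pkg")
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    # clear any existing handlers to avoid duplicate logs on re-init
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler):
            handler.close()
    logger.handlers.clear()

    formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT)
    file_handler = logging.FileHandler(str(log_file), mode="a")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger


log_path = Path(tempfile.mkdtemp()) / "logs.txt"
init_demo_loggers(log_path)
# a child logger ("demo_pkg.workflow") propagates up to the "demo_pkg" handlers
logging.getLogger("demo_pkg.workflow").info("pipeline started")
text = log_path.read_text()
```

This is why the rest of the PR can simply rename module loggers to `logging.getLogger(__name__)`: any module under the package namespace automatically inherits this configuration.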
@@ -1,22 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Logging types.
-
-This module defines the types of loggers that can be used.
-"""
-
-from enum import Enum
-
-
-# Note: Code in this module was not included in the factory module because it negatively impacts the CLI experience.
-class LoggerType(str, Enum):
-    """The type of logger to use."""
-
-    RICH = "rich"
-    PRINT = "print"
-    NONE = "none"
-
-    def __str__(self):
-        """Return a string representation of the enum value."""
-        return self.value
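The removed `LoggerType` relies on subclassing both `str` and `Enum`: members compare equal to their plain-string values (handy for config parsing), while the `__str__` override keeps CLI output as the bare value rather than `LoggerType.RICH`. A minimal runnable illustration (the `Mode` enum here is hypothetical):

```python
from enum import Enum


class Mode(str, Enum):
    """Subclassing str makes members compare equal to their plain-string values."""

    RICH = "rich"
    NONE = "none"

    def __str__(self) -> str:
        # without this override, f"{Mode.RICH}" could render as "Mode.RICH"
        return self.value
```

`Mode("rich")` also round-trips a config string back to the enum member.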
@@ -3,6 +3,8 @@

 """Input loading module."""

+import logging
+
 import numpy as np
 import pandas as pd

@@ -14,7 +16,6 @@ from graphrag.index.operations.embed_text.strategies.openai import (
     run as run_embed_text,
 )
 from graphrag.index.workflows.create_base_text_units import create_base_text_units
-from graphrag.logger.base import ProgressLogger
 from graphrag.prompt_tune.defaults import (
     LIMIT,
     N_SUBSET_MAX,
@@ -41,7 +42,7 @@ async def load_docs_in_chunks(
     config: GraphRagConfig,
     select_method: DocSelectionType,
     limit: int,
-    logger: ProgressLogger,
+    logger: logging.Logger,
     chunk_size: int,
     overlap: int,
     n_subset_max: int = N_SUBSET_MAX,
@@ -52,7 +53,7 @@ async def load_docs_in_chunks(
         config.embed_text.model_id
     )
     input_storage = create_storage_from_config(config.input.storage)
-    dataset = await create_input(config.input, input_storage, logger)
+    dataset = await create_input(config.input, input_storage)
     chunk_config = config.chunks
     chunks_df = create_base_text_units(
         documents=dataset,
@@ -14,7 +14,7 @@ from graphrag.data_model.community_report import CommunityReport
 from graphrag.data_model.entity import Entity
 from graphrag.query.llm.text_utils import num_tokens

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)

 NO_COMMUNITY_RECORDS_WARNING: str = (
     "Warning: No community records added when building community context."
@@ -88,7 +88,7 @@ def build_community_context(
         )
     )
     if compute_community_weights:
-        log.info("Computing community weights...")
+        logger.debug("Computing community weights...")
         community_reports = _compute_community_weights(
             community_reports=community_reports,
             entities=entities,
@@ -177,7 +177,7 @@ def build_community_context(
     _cut_batch()

     if len(all_context_records) == 0:
-        log.warning(NO_COMMUNITY_RECORDS_WARNING)
+        logger.warning(NO_COMMUNITY_RECORDS_WARNING)
         return ([], {})

     return all_context_text, {
@@ -18,7 +18,7 @@ from graphrag.language_model.protocol.base import ChatModel
 from graphrag.query.context_builder.rate_prompt import RATE_QUERY
 from graphrag.query.context_builder.rate_relevancy import rate_relevancy

-log = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)


 class DynamicCommunitySelection:
@@ -109,7 +109,7 @@ class DynamicCommunitySelection:
         communities_to_rate = []
         for community, result in zip(queue, gather_results, strict=True):
             rating = result["rating"]
-            log.debug(
+            logger.debug(
                 "dynamic community selection: community %s rating %s",
                 community,
                 rating,
@@ -127,7 +127,7 @@ class DynamicCommunitySelection:
                 if child in self.reports:
                     communities_to_rate.append(child)
                 else:
-                    log.debug(
+                    logger.debug(
                         "dynamic community selection: cannot find community %s in reports",
                         child,
                     )
@@ -142,7 +142,7 @@ class DynamicCommunitySelection:
                 and (str(level) in self.levels)
                 and (level <= self.max_level)
             ):
-                log.info(
+                logger.debug(
                     "dynamic community selection: no relevant community "
                     "reports, adding all reports at level %s to rate.",
                     level,
@@ -155,7 +155,7 @@ class DynamicCommunitySelection:
             ]
         end = time()

-        log.info(
+        logger.debug(
            "dynamic community selection (took: %ss)\n"
            "\trating distribution %s\n"
            "\t%s out of %s community reports are relevant\n"
Some files were not shown because too many files have changed in this diff.