Feat/additional context (#2021)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled

* Users/snehitgajjar/add optional api param for pipeline state (#2019)

* Add support for additional context for PipelineState

* Clean up

* Clean up

* Clean up

* Nit

---------

Co-authored-by: Snehit Gajjar <snehitgajjar@microsoft.com>

* Semver

* Update graphrag/api/index.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Remove additional_context from serialization

---------

Co-authored-by: Snehit Gajjar <snehit.gajjar@gmail.com>
Co-authored-by: Snehit Gajjar <snehitgajjar@microsoft.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Alonso Guevara 2025-08-08 16:59:24 -06:00 committed by GitHub
parent 1da1380615
commit 5713205210
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 25 additions and 3 deletions

View File

@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Add additional context variable to build index signature for custom parameter bag"
+}

View File

@@ -9,6 +9,7 @@ Backwards compatibility is not guaranteed at this time.
 """
 import logging
+from typing import Any
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -30,6 +31,7 @@ async def build_index(
     is_update_run: bool = False,
     memory_profile: bool = False,
     callbacks: list[WorkflowCallbacks] | None = None,
+    additional_context: dict[str, Any] | None = None,
 ) -> list[PipelineRunResult]:
     """Run the pipeline with the given configuration.
@@ -43,6 +45,8 @@ async def build_index(
         Whether to enable memory profiling.
     callbacks : list[WorkflowCallbacks] | None default=None
         A list of callbacks to register.
+    additional_context : dict[str, Any] | None default=None
+        Additional context to pass to the pipeline run. This can be accessed in the pipeline state under the 'additional_context' key.
     Returns
     -------
@@ -73,6 +77,7 @@ async def build_index(
         config,
         callbacks=workflow_callbacks,
         is_update_run=is_update_run,
+        additional_context=additional_context,
     ):
         outputs.append(output)
         if output.errors and len(output.errors) > 0:

View File

@@ -9,6 +9,7 @@ import re
 import time
 from collections.abc import AsyncIterable
 from dataclasses import asdict
+from typing import Any
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -28,6 +29,7 @@ async def run_pipeline(
     config: GraphRagConfig,
     callbacks: WorkflowCallbacks,
     is_update_run: bool = False,
+    additional_context: dict[str, Any] | None = None,
 ) -> AsyncIterable[PipelineRunResult]:
     """Run all workflows using a simplified pipeline."""
     root_dir = config.root_dir
@@ -40,6 +42,9 @@ async def run_pipeline(
     state_json = await output_storage.get("context.json")
     state = json.loads(state_json) if state_json else {}
+    if additional_context:
+        state.setdefault("additional_context", {}).update(additional_context)
     if is_update_run:
         logger.info("Running incremental indexing.")
@@ -126,9 +131,17 @@ async def _dump_json(context: PipelineRunContext) -> None:
     await context.output_storage.set(
         "stats.json", json.dumps(asdict(context.stats), indent=4, ensure_ascii=False)
     )
-    await context.output_storage.set(
-        "context.json", json.dumps(context.state, indent=4, ensure_ascii=False)
-    )
+    # Dump context state, excluding additional_context
+    temp_context = context.state.pop(
+        "additional_context", None
+    )  # Remove reference only, as object size is uncertain
+    try:
+        state_blob = json.dumps(context.state, indent=4, ensure_ascii=False)
+    finally:
+        if temp_context:
+            context.state["additional_context"] = temp_context
+    await context.output_storage.set("context.json", state_blob)
async def _copy_previous_output( async def _copy_previous_output(