From 57132052102d777d71ec69bc20be288fc01eb02a Mon Sep 17 00:00:00 2001
From: Alonso Guevara
Date: Fri, 8 Aug 2025 16:59:24 -0600
Subject: [PATCH] Feat/additional context (#2021)

* Users/snehitgajjar/add optional api param for pipeline state (#2019)

* Add support for additional context for PipelineState

* Clean up

* Clean up

* Clean up

* Nit

---------

Co-authored-by: Snehit Gajjar

* Semver

* Update graphrag/api/index.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Remove additional_context from serialization

---------

Co-authored-by: Snehit Gajjar
Co-authored-by: Snehit Gajjar
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../minor-20250807204918927024.json           |  4 ++++
 graphrag/api/index.py                         |  5 +++++
 graphrag/index/run/run_pipeline.py            | 19 ++++++++++++++++---
 3 files changed, 25 insertions(+), 3 deletions(-)
 create mode 100644 .semversioner/next-release/minor-20250807204918927024.json

diff --git a/.semversioner/next-release/minor-20250807204918927024.json b/.semversioner/next-release/minor-20250807204918927024.json
new file mode 100644
index 00000000..09843fa2
--- /dev/null
+++ b/.semversioner/next-release/minor-20250807204918927024.json
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Add additional context variable to build index signature for custom parameter bag"
+}
diff --git a/graphrag/api/index.py b/graphrag/api/index.py
index c7f7f078..a9141732 100644
--- a/graphrag/api/index.py
+++ b/graphrag/api/index.py
@@ -9,6 +9,7 @@ Backwards compatibility is not guaranteed at this time.
 """
 
 import logging
+from typing import Any
 
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -30,6 +31,7 @@ async def build_index(
     is_update_run: bool = False,
     memory_profile: bool = False,
     callbacks: list[WorkflowCallbacks] | None = None,
+    additional_context: dict[str, Any] | None = None,
 ) -> list[PipelineRunResult]:
     """Run the pipeline with the given configuration.
 
@@ -43,6 +45,8 @@ async def build_index(
         Whether to enable memory profiling.
     callbacks : list[WorkflowCallbacks] | None default=None
         A list of callbacks to register.
+    additional_context : dict[str, Any] | None default=None
+        Additional context to pass to the pipeline run. This can be accessed in the pipeline state under the 'additional_context' key.
 
     Returns
     -------
@@ -73,6 +77,7 @@ async def build_index(
         config,
         callbacks=workflow_callbacks,
         is_update_run=is_update_run,
+        additional_context=additional_context,
     ):
         outputs.append(output)
         if output.errors and len(output.errors) > 0:
diff --git a/graphrag/index/run/run_pipeline.py b/graphrag/index/run/run_pipeline.py
index a628b5ad..d2047058 100644
--- a/graphrag/index/run/run_pipeline.py
+++ b/graphrag/index/run/run_pipeline.py
@@ -9,6 +9,7 @@ import re
 import time
 from collections.abc import AsyncIterable
 from dataclasses import asdict
+from typing import Any
 
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -28,6 +29,7 @@ async def run_pipeline(
     config: GraphRagConfig,
     callbacks: WorkflowCallbacks,
     is_update_run: bool = False,
+    additional_context: dict[str, Any] | None = None,
 ) -> AsyncIterable[PipelineRunResult]:
     """Run all workflows using a simplified pipeline."""
     root_dir = config.root_dir
@@ -40,6 +42,9 @@ async def run_pipeline(
     state_json = await output_storage.get("context.json")
     state = json.loads(state_json) if state_json else {}
 
+    if additional_context:
+        state.setdefault("additional_context", {}).update(additional_context)
+
     if is_update_run:
         logger.info("Running incremental indexing.")
 
@@ -126,9 +131,17 @@ async def _dump_json(context: PipelineRunContext) -> None:
     await context.output_storage.set(
         "stats.json", json.dumps(asdict(context.stats), indent=4, ensure_ascii=False)
     )
-    await context.output_storage.set(
-        "context.json", json.dumps(context.state, indent=4, ensure_ascii=False)
-    )
+    # Dump context state, excluding additional_context
+    temp_context = context.state.pop(
+        "additional_context", None
+    )  # Remove reference only, as object size is uncertain
+    try:
+        state_blob = json.dumps(context.state, indent=4, ensure_ascii=False)
+    finally:
+        if temp_context:
+            context.state["additional_context"] = temp_context
+
+    await context.output_storage.set("context.json", state_blob)
 
 
 async def _copy_previous_output(
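
Usage sketch (reviewer note, not part of the patch): the new parameter threads a caller-supplied bag into the pipeline state. Only build_index, the additional_context keyword, PipelineRunResult.errors, and the two import paths are taken from the diff above; the function name, dict keys, and config handling are illustrative assumptions.

from typing import Any

from graphrag.api.index import build_index
from graphrag.config.models.graph_rag_config import GraphRagConfig


async def index_with_context(config: GraphRagConfig) -> None:
    # Hypothetical values: anything placed here lands in the pipeline state
    # under the "additional_context" key, and the patch deliberately excludes
    # it when the state is serialized to context.json.
    extra: dict[str, Any] = {"tenant_id": "acme", "trigger": "nightly"}
    results = await build_index(config, additional_context=extra)
    for result in results:
        if result.errors:
            print(result.errors)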
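
One behavioral detail worth calling out: run_pipeline merges the bag with state.setdefault("additional_context", {}).update(...), so repeated merges add keys to the same bag rather than replacing it, and an empty or None bag is skipped entirely. A minimal stand-in sketch of that semantics, using a plain dict in place of the real pipeline state and a made-up helper name:

from typing import Any


def merge_additional_context(
    state: dict[str, Any], additional_context: dict[str, Any] | None
) -> None:
    # Mirrors the patch: create the bag on first use, then merge keys into it.
    if additional_context:
        state.setdefault("additional_context", {}).update(additional_context)


state: dict[str, Any] = {}  # stands in for PipelineRunContext.state
merge_additional_context(state, {"tenant_id": "acme"})
merge_additional_context(state, {"trigger": "nightly"})
assert state["additional_context"] == {"tenant_id": "acme", "trigger": "nightly"}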