mirror of
https://github.com/microsoft/graphrag.git
synced 2026-02-18 16:55:50 +08:00
Some checks failed
Python Build and Type Check / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python Build and Type Check / python-ci (ubuntu-latest, 3.12) (push) Has been cancelled
Python Build and Type Check / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Build and Type Check / python-ci (windows-latest, 3.12) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.12) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.12) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.12) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.12) (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.12) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.12) (push) Has been cancelled
Python Unit Tests / python-ci (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unit Tests / python-ci (windows-latest, 3.12) (push) Has been cancelled
* Add load_config to graphrag-common package.
55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
# Copyright (c) 2024 Microsoft Corporation.
|
|
# Licensed under the MIT License
|
|
|
|
"""Tests for pipeline state passthrough."""
|
|
|
|
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
|
from graphrag.index.run.utils import create_run_context
|
|
from graphrag.index.typing.context import PipelineRunContext
|
|
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
|
from graphrag.index.workflows.factory import PipelineFactory
|
|
|
|
from tests.verbs.util import DEFAULT_MODEL_CONFIG
|
|
|
|
|
|
async def run_workflow_1(  # noqa: RUF029
    _config: GraphRagConfig, context: PipelineRunContext
):
    """Test workflow: seed the shared pipeline state with ``count = 1``.

    The config argument is unused; it is accepted only because the pipeline
    invokes every workflow as ``fn(config, context)``.
    """
    shared_state = context.state
    # Overwrite unconditionally so this workflow always starts the count fresh.
    shared_state["count"] = 1
    return WorkflowFunctionOutput(result=None)
|
|
|
|
|
|
async def run_workflow_2(  # noqa: RUF029
    _config: GraphRagConfig, context: PipelineRunContext
):
    """Test workflow: increment ``count`` in the shared pipeline state.

    Expects an earlier workflow (or the initial context state) to have set
    ``count``; a missing key raises ``KeyError``, which would fail the test.
    """
    shared_state = context.state
    shared_state["count"] = shared_state["count"] + 1
    return WorkflowFunctionOutput(result=None)
|
|
|
|
|
|
async def test_pipeline_state():
    """Check that workflows in one pipeline run share a mutable state block.

    ``workflow_1`` sets ``count`` to 1 and ``workflow_2`` increments it, so
    after both run against the same context the count must be 2.
    """
    # Insertion order matters: it defines both registration order and the
    # order the pipeline runs the workflows in.
    workflows = {"workflow_1": run_workflow_1, "workflow_2": run_workflow_2}
    for workflow_name, workflow_fn in workflows.items():
        PipelineFactory.register(workflow_name, workflow_fn)

    config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG)  # type: ignore
    config.workflows = list(workflows)
    context = create_run_context()

    pipeline = PipelineFactory.create_pipeline(config)
    for _name, step in pipeline.run():
        await step(config, context)

    assert context.state["count"] == 2
|
|
|
|
|
|
async def test_pipeline_existing_state():
    """Check that state passed to ``create_run_context`` reaches the workflows.

    The context is seeded with ``count = 4``; running only ``workflow_2``
    (an increment) must leave ``count`` at 5.
    """
    PipelineFactory.register("workflow_2", run_workflow_2)

    config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG)  # type: ignore
    config.workflows = ["workflow_2"]
    seed_state = {"count": 4}
    context = create_run_context(state=seed_state)

    pipeline = PipelineFactory.create_pipeline(config)
    for _name, step in pipeline.run():
        await step(config, context)

    # Compare against a literal: the context may alias ``seed_state``, so
    # deriving the expectation from it could read the already-mutated value.
    assert context.state["count"] == 5
|