feat(graph-engine): add command to update variables at runtime (#30563)
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
-LAN- 2026-01-05 16:47:34 +08:00 committed by GitHub
parent 6f8bd58e19
commit a9e2c05a10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 194 additions and 13 deletions

View File

@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@ -113,6 +113,8 @@ class RedisChannel:
return AbortCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
if command_type == CommandType.UPDATE_VARIABLES:
return UpdateVariablesCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)

View File

@ -5,11 +5,12 @@ This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler, PauseCommandHandler
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
"UpdateVariablesCommandHandler",
]

View File

@ -4,9 +4,10 @@ from typing import final
from typing_extensions import override
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.runtime import VariablePool
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
reason = command.reason
pause_reason = SchedulingPause(message=reason)
execution.pause(pause_reason)
@final
class UpdateVariablesCommandHandler(CommandHandler):
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, UpdateVariablesCommand)
for update in command.updates:
try:
variable = update.value
self._variable_pool.add(variable.selector, variable)
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
except ValueError as exc:
logger.warning(
"Skipping invalid variable selector %s for workflow %s: %s",
getattr(update.value, "selector", None),
execution.workflow_id,
exc,
)

View File

@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""
from enum import StrEnum
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
class CommandType(StrEnum):
"""Types of commands that can be sent to GraphEngine."""
ABORT = "abort"
PAUSE = "pause"
ABORT = auto()
PAUSE = auto()
UPDATE_VARIABLES = auto()
class GraphEngineCommand(BaseModel):
@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str = Field(default="unknown reason", description="reason for pause")
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):
"""Command to update a group of variables in the variable pool."""
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")

View File

@ -30,8 +30,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
from .entities.commands import AbortCommand, PauseCommand
from .command_processing import (
AbortCommandHandler,
CommandProcessor,
PauseCommandHandler,
UpdateVariablesCommandHandler,
)
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@ -140,6 +145,9 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []

View File

@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Supports stop, pause, and resume operations.
"""
import logging
from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
GraphEngineCommand,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@ -23,7 +29,6 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Supports stop and pause operations.
"""
@staticmethod
@ -45,6 +50,16 @@ class GraphEngineManager:
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
@staticmethod
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
"""Send a command to update variables in a running workflow."""
if not updates:
return
update_command = UpdateVariablesCommand(updates=updates)
GraphEngineManager._send_command(task_id, update_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""

View File

@ -3,8 +3,15 @@
import json
from unittest.mock import MagicMock
from core.variables import IntegerVariable, StringVariable
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
CommandType,
GraphEngineCommand,
UpdateVariablesCommand,
VariableUpdate,
)
class TestRedisChannel:
@ -148,6 +155,43 @@ class TestRedisChannel:
assert commands[0].command_type == CommandType.ABORT
assert isinstance(commands[1], AbortCommand)
def test_fetch_commands_with_update_variables_command(self):
"""Test fetching update variables command from Redis."""
mock_redis = MagicMock()
pending_pipe = MagicMock()
fetch_pipe = MagicMock()
pending_context = MagicMock()
fetch_context = MagicMock()
pending_context.__enter__.return_value = pending_pipe
pending_context.__exit__.return_value = None
fetch_context.__enter__.return_value = fetch_pipe
fetch_context.__exit__.return_value = None
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
update_command = UpdateVariablesCommand(
updates=[
VariableUpdate(
value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]),
),
VariableUpdate(
value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]),
),
]
)
command_json = json.dumps(update_command.model_dump())
pending_pipe.execute.return_value = [b"1", 1]
fetch_pipe.execute.return_value = [[command_json.encode()], 1]
channel = RedisChannel(mock_redis, "test:key")
commands = channel.fetch_commands()
assert len(commands) == 1
assert isinstance(commands[0], UpdateVariablesCommand)
assert isinstance(commands[0].updates[0].value, StringVariable)
assert list(commands[0].updates[0].value.selector) == ["node1", "foo"]
assert commands[0].updates[0].value.value == "bar"
def test_fetch_commands_skips_invalid_json(self):
"""Test that invalid JSON commands are skipped."""
mock_redis = MagicMock()

View File

@ -4,12 +4,19 @@ import time
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import IntegerVariable, StringVariable
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
CommandType,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
@ -180,3 +187,67 @@ def test_pause_command():
graph_execution = engine.graph_runtime_state.graph_execution
assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]
def test_update_variables_command_updates_pool():
"""Test that GraphEngine updates variable pool via update variables command."""
shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
shared_runtime_state.variable_pool.add(("node1", "foo"), "old value")
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start"
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=shared_runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
command_channel = InMemoryChannel()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=shared_runtime_state,
command_channel=command_channel,
)
update_command = UpdateVariablesCommand(
updates=[
VariableUpdate(
value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]),
),
VariableUpdate(
value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]),
),
]
)
command_channel.send_command(update_command)
list(engine.run())
updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"])
added_new = shared_runtime_state.variable_pool.get(["node2", "bar"])
assert updated_existing is not None
assert updated_existing.value == "new value"
assert added_new is not None
assert added_new.value == 123