mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
feat(graph-engine): add command to update variables at runtime (#30563)
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
6f8bd58e19
commit
a9e2c05a10
@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
@ -113,6 +113,8 @@ class RedisChannel:
|
||||
return AbortCommand.model_validate(data)
|
||||
if command_type == CommandType.PAUSE:
|
||||
return PauseCommand.model_validate(data)
|
||||
if command_type == CommandType.UPDATE_VARIABLES:
|
||||
return UpdateVariablesCommand.model_validate(data)
|
||||
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand.model_validate(data)
|
||||
|
||||
@ -5,11 +5,12 @@ This package handles external commands sent to the engine
|
||||
during execution.
|
||||
"""
|
||||
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
|
||||
from .command_processor import CommandProcessor
|
||||
|
||||
__all__ = [
|
||||
"AbortCommandHandler",
|
||||
"CommandProcessor",
|
||||
"PauseCommandHandler",
|
||||
"UpdateVariablesCommandHandler",
|
||||
]
|
||||
|
||||
@ -4,9 +4,10 @@ from typing import final
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
|
||||
reason = command.reason
|
||||
pause_reason = SchedulingPause(message=reason)
|
||||
execution.pause(pause_reason)
|
||||
|
||||
|
||||
@final
|
||||
class UpdateVariablesCommandHandler(CommandHandler):
|
||||
def __init__(self, variable_pool: VariablePool) -> None:
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, UpdateVariablesCommand)
|
||||
for update in command.updates:
|
||||
try:
|
||||
variable = update.value
|
||||
self._variable_pool.add(variable.selector, variable)
|
||||
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
|
||||
except ValueError as exc:
|
||||
logger.warning(
|
||||
"Skipping invalid variable selector %s for workflow %s: %s",
|
||||
getattr(update.value, "selector", None),
|
||||
execution.workflow_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
|
||||
instance to control its execution flow.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.variables.variables import VariableUnion
|
||||
|
||||
|
||||
class CommandType(StrEnum):
|
||||
"""Types of commands that can be sent to GraphEngine."""
|
||||
|
||||
ABORT = "abort"
|
||||
PAUSE = "pause"
|
||||
ABORT = auto()
|
||||
PAUSE = auto()
|
||||
UPDATE_VARIABLES = auto()
|
||||
|
||||
|
||||
class GraphEngineCommand(BaseModel):
|
||||
@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
|
||||
reason: str = Field(default="unknown reason", description="reason for pause")
|
||||
|
||||
|
||||
class VariableUpdate(BaseModel):
|
||||
"""Represents a single variable update instruction."""
|
||||
|
||||
value: VariableUnion = Field(description="New variable value")
|
||||
|
||||
|
||||
class UpdateVariablesCommand(GraphEngineCommand):
|
||||
"""Command to update a group of variables in the variable pool."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
|
||||
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")
|
||||
|
||||
@ -30,8 +30,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
|
||||
from core.workflow.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
|
||||
from .entities.commands import AbortCommand, PauseCommand
|
||||
from .command_processing import (
|
||||
AbortCommandHandler,
|
||||
CommandProcessor,
|
||||
PauseCommandHandler,
|
||||
UpdateVariablesCommandHandler,
|
||||
)
|
||||
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
|
||||
from .error_handler import ErrorHandler
|
||||
from .event_management import EventHandler, EventManager
|
||||
from .graph_state_manager import GraphStateManager
|
||||
@ -140,6 +145,9 @@ class GraphEngine:
|
||||
pause_handler = PauseCommandHandler()
|
||||
self._command_processor.register_handler(PauseCommand, pause_handler)
|
||||
|
||||
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
|
||||
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
|
||||
|
||||
# === Extensibility ===
|
||||
# Layers allow plugins to extend engine functionality
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
|
||||
@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
GraphEngineCommand,
|
||||
PauseCommand,
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -23,7 +29,6 @@ class GraphEngineManager:
|
||||
|
||||
This class provides a simple interface for controlling workflow executions
|
||||
by sending commands through Redis channels, without user validation.
|
||||
Supports stop and pause operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@ -45,6 +50,16 @@ class GraphEngineManager:
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
GraphEngineManager._send_command(task_id, pause_command)
|
||||
|
||||
@staticmethod
|
||||
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
"""Send a command to update variables in a running workflow."""
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
update_command = UpdateVariablesCommand(updates=updates)
|
||||
GraphEngineManager._send_command(task_id, update_command)
|
||||
|
||||
@staticmethod
|
||||
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
@ -3,8 +3,15 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.variables import IntegerVariable, StringVariable
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
CommandType,
|
||||
GraphEngineCommand,
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
|
||||
|
||||
class TestRedisChannel:
|
||||
@ -148,6 +155,43 @@ class TestRedisChannel:
|
||||
assert commands[0].command_type == CommandType.ABORT
|
||||
assert isinstance(commands[1], AbortCommand)
|
||||
|
||||
def test_fetch_commands_with_update_variables_command(self):
|
||||
"""Test fetching update variables command from Redis."""
|
||||
mock_redis = MagicMock()
|
||||
pending_pipe = MagicMock()
|
||||
fetch_pipe = MagicMock()
|
||||
pending_context = MagicMock()
|
||||
fetch_context = MagicMock()
|
||||
pending_context.__enter__.return_value = pending_pipe
|
||||
pending_context.__exit__.return_value = None
|
||||
fetch_context.__enter__.return_value = fetch_pipe
|
||||
fetch_context.__exit__.return_value = None
|
||||
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
|
||||
|
||||
update_command = UpdateVariablesCommand(
|
||||
updates=[
|
||||
VariableUpdate(
|
||||
value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]),
|
||||
),
|
||||
VariableUpdate(
|
||||
value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]),
|
||||
),
|
||||
]
|
||||
)
|
||||
command_json = json.dumps(update_command.model_dump())
|
||||
|
||||
pending_pipe.execute.return_value = [b"1", 1]
|
||||
fetch_pipe.execute.return_value = [[command_json.encode()], 1]
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
assert len(commands) == 1
|
||||
assert isinstance(commands[0], UpdateVariablesCommand)
|
||||
assert isinstance(commands[0].updates[0].value, StringVariable)
|
||||
assert list(commands[0].updates[0].value.selector) == ["node1", "foo"]
|
||||
assert commands[0].updates[0].value.value == "bar"
|
||||
|
||||
def test_fetch_commands_skips_invalid_json(self):
|
||||
"""Test that invalid JSON commands are skipped."""
|
||||
mock_redis = MagicMock()
|
||||
|
||||
@ -4,12 +4,19 @@ import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import IntegerVariable, StringVariable
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
CommandType,
|
||||
PauseCommand,
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
@ -180,3 +187,67 @@ def test_pause_command():
|
||||
|
||||
graph_execution = engine.graph_runtime_state.graph_execution
|
||||
assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]
|
||||
|
||||
|
||||
def test_update_variables_command_updates_pool():
|
||||
"""Test that GraphEngine updates variable pool via update variables command."""
|
||||
|
||||
shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
shared_runtime_state.variable_pool.add(("node1", "foo"), "old value")
|
||||
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start"
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=shared_runtime_state,
|
||||
)
|
||||
mock_graph.nodes["start"] = start_node
|
||||
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
command_channel = InMemoryChannel()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=shared_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
update_command = UpdateVariablesCommand(
|
||||
updates=[
|
||||
VariableUpdate(
|
||||
value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]),
|
||||
),
|
||||
VariableUpdate(
|
||||
value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]),
|
||||
),
|
||||
]
|
||||
)
|
||||
command_channel.send_command(update_command)
|
||||
|
||||
list(engine.run())
|
||||
|
||||
updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"])
|
||||
added_new = shared_runtime_state.variable_pool.get(["node2", "bar"])
|
||||
|
||||
assert updated_existing is not None
|
||||
assert updated_existing.value == "new value"
|
||||
assert added_new is not None
|
||||
assert added_new.value == 123
|
||||
|
||||
Loading…
Reference in New Issue
Block a user