feat(api): enhance workflow validation and structure checks

- Added a new validation class to ensure that trigger nodes do not coexist with UserInput (start) nodes in the workflow graph.
- Implemented a method in WorkflowService to validate the graph structure before persisting workflows, leveraging the new validation logic.
- Updated unit tests to cover the new validation scenarios and ensure proper error propagation.
This commit is contained in:
Harry 2025-11-11 14:52:04 +08:00
parent 7484a020e1
commit aad31bb703
5 changed files with 132 additions and 3 deletions

View File

@ -43,6 +43,9 @@ class InvokeFrom(StrEnum):
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
PUBLISHED = "published"
# VALIDATION indicates that this invocation is from validation.
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str):

View File

@ -114,9 +114,45 @@ class GraphValidator:
raise GraphValidationError(issues)
@dataclass(frozen=True, slots=True)
class _TriggerStartExclusivityValidator:
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
start_node_id: str | None = None
trigger_node_ids: list[str] = []
for node in graph.nodes.values():
node_type = getattr(node, "node_type", None)
if not isinstance(node_type, NodeType):
continue
if node_type == NodeType.START:
start_node_id = node.id
elif node_type.is_trigger_node:
trigger_node_ids.append(node.id)
if start_node_id and trigger_node_ids:
trigger_list = ", ".join(trigger_node_ids)
return [
GraphValidationIssue(
code=self.conflict_code,
message=(
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
),
node_id=start_node_id,
)
]
return []
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
_TriggerStartExclusivityValidator(),
)

View File

@ -10,20 +10,22 @@ from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool, WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@ -32,6 +34,7 @@ from extensions.ext_storage import storage
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import UserFrom
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
@ -211,6 +214,9 @@ class WorkflowService:
# validate features structure
self.validate_features_structure(app_model=app_model, features=features)
# validate graph structure
self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=graph)
# create draft workflow if not found
if not workflow:
workflow = Workflow(
@ -267,6 +273,9 @@ class WorkflowService:
if FeatureService.get_system_features().plugin_manager.enabled:
self._validate_workflow_credentials(draft_workflow)
# validate graph structure
self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=draft_workflow.graph_dict)
# create new workflow
workflow = Workflow.new(
tenant_id=app_model.tenant_id,
@ -896,6 +905,36 @@ class WorkflowService:
return new_app
def validate_graph_structure(self, user_id: str, app_model: App, graph: Mapping[str, Any]) -> None:
"""
Validate workflow graph structure by instantiating the Graph object.
This leverages the built-in graph validators (including trigger/UserInput exclusivity)
and raises any structural errors before persisting the workflow.
"""
Graph.init(
graph_config=graph,
# TODO(Mairuis): Add root node id
root_node_id=None,
node_factory=DifyNodeFactory(
graph_init_params=GraphInitParams(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
workflow_id=app_model.workflow_id,
graph_config=graph,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.VALIDATION,
call_depth=0,
),
graph_runtime_state=GraphRuntimeState(
variable_pool=VariablePool(),
start_at=time.perf_counter(),
),
),
)
def validate_features_structure(self, app_model: App, features: dict):
if app_model.mode == AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(

View File

@ -64,6 +64,15 @@ class _TestNode(Node):
)
self.data = dict(data)
node_type_value = data.get("type")
if isinstance(node_type_value, NodeType):
self.node_type = node_type_value
elif isinstance(node_type_value, str):
try:
self.node_type = NodeType(node_type_value)
except ValueError:
pass
def _run(self):
raise NotImplementedError
@ -179,3 +188,22 @@ def test_graph_promotes_fail_branch_nodes_to_branch_execution_type(
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH
def test_graph_validation_blocks_start_and_trigger_coexistence(
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
) -> None:
node_factory, graph_config = graph_init_dependencies
graph_config["nodes"] = [
{"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
{
"id": "trigger",
"data": {"type": NodeType.TRIGGER_WEBHOOK, "title": "Webhook", "execution_type": NodeExecutionType.ROOT},
},
]
graph_config["edges"] = []
with pytest.raises(GraphValidationError) as exc_info:
Graph.init(graph_config=graph_config, node_factory=node_factory)
assert any(issue.code == "TRIGGER_START_NODE_CONFLICT" for issue in exc_info.value.issues)

View File

@ -1,7 +1,8 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
from core.workflow.graph.validation import GraphValidationError, GraphValidationIssue
from models.model import App
from models.workflow import Workflow
from services.workflow_service import WorkflowService
@ -161,3 +162,25 @@ class TestWorkflowService:
assert workflows == []
assert has_more is False
mock_session.scalars.assert_called_once()
def test_validate_graph_structure_invokes_graph_init(self, workflow_service, mock_app):
graph = {"nodes": [], "edges": []}
with patch("services.workflow_service.Graph.init") as mock_graph_init:
workflow_service.validate_graph_structure(mock_app, graph)
mock_graph_init.assert_called_once()
assert mock_graph_init.call_args.kwargs["graph_config"] is graph
assert "node_factory" in mock_graph_init.call_args.kwargs
def test_validate_graph_structure_propagates_graph_errors(self, workflow_service, mock_app):
graph = {"nodes": [], "edges": []}
issue = GraphValidationIssue(code="ERR", message="invalid")
with patch("services.workflow_service.Graph.init", side_effect=GraphValidationError([issue])):
with pytest.raises(GraphValidationError):
workflow_service.validate_graph_structure(mock_app, graph)
def test_validate_graph_structure_requires_nodes_and_edges(self, workflow_service, mock_app):
with pytest.raises(ValueError, match="must include 'nodes' and 'edges'"):
workflow_service.validate_graph_structure(mock_app, {"nodes": []})