diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 909877e31c..6ee1295a78 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -43,6 +43,9 @@ class InvokeFrom(StrEnum): # the workflow (or chatflow) edit page. DEBUGGER = "debugger" PUBLISHED = "published" + + # VALIDATION indicates that this invocation is from validation. + VALIDATION = "validation" @classmethod def value_of(cls, value: str): diff --git a/api/core/workflow/graph/validation.py b/api/core/workflow/graph/validation.py index 87aa7db2e4..41b4fdfa60 100644 --- a/api/core/workflow/graph/validation.py +++ b/api/core/workflow/graph/validation.py @@ -114,9 +114,45 @@ class GraphValidator: raise GraphValidationError(issues) +@dataclass(frozen=True, slots=True) +class _TriggerStartExclusivityValidator: + """Ensures trigger nodes do not coexist with UserInput (start) nodes.""" + + conflict_code: str = "TRIGGER_START_NODE_CONFLICT" + + def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: + start_node_id: str | None = None + trigger_node_ids: list[str] = [] + + for node in graph.nodes.values(): + node_type = getattr(node, "node_type", None) + if not isinstance(node_type, NodeType): + continue + + if node_type == NodeType.START: + start_node_id = node.id + elif node_type.is_trigger_node: + trigger_node_ids.append(node.id) + + if start_node_id and trigger_node_ids: + trigger_list = ", ".join(trigger_node_ids) + return [ + GraphValidationIssue( + code=self.conflict_code, + message=( + f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}." + ), + node_id=start_node_id, + ) + ] + + return [] + + _DEFAULT_RULES: tuple[GraphValidationRule, ...] = ( _EdgeEndpointValidator(), _RootNodeValidator(), + _TriggerStartExclusivityValidator(), ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index f36f2fea4e..8743945409 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -10,20 +10,22 @@ from sqlalchemy.orm import Session, sessionmaker from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable from core.variables.variables import VariableUnion -from core.workflow.entities import WorkflowNodeExecution +from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool, WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.graph.graph import Graph from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent from core.workflow.node_events import NodeRunResult from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node +from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated @@ -32,6 +34,7 @@ from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from models import Account +from models.enums import UserFrom from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType @@ -211,6 +214,9 @@ class WorkflowService: # validate features structure self.validate_features_structure(app_model=app_model, features=features) + # validate graph structure + self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=graph) + # create draft workflow if not found if not workflow: workflow = Workflow( @@ -267,6 +273,9 @@ class WorkflowService: if FeatureService.get_system_features().plugin_manager.enabled: self._validate_workflow_credentials(draft_workflow) + # validate graph structure + self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=draft_workflow.graph_dict) + # create new workflow workflow = Workflow.new( tenant_id=app_model.tenant_id, @@ -896,6 +905,36 @@ class WorkflowService: return new_app + def validate_graph_structure(self, user_id: str, app_model: App, graph: Mapping[str, Any]) -> None: + """ + Validate workflow graph structure by instantiating the Graph object. + + This leverages the built-in graph validators (including trigger/UserInput exclusivity) + and raises any structural errors before persisting the workflow. + """ + + Graph.init( + graph_config=graph, + # TODO(Mairuis): Add root node id + root_node_id=None, + node_factory=DifyNodeFactory( + graph_init_params=GraphInitParams( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=app_model.workflow_id, + graph_config=graph, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.VALIDATION, + call_depth=0, + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=VariablePool(), + start_at=time.perf_counter(), + ), + ), + ) + def validate_features_structure(self, app_model: App, features: dict): if app_model.mode == AppMode.ADVANCED_CHAT: return AdvancedChatAppConfigManager.config_validate( diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index b55d4998c4..c55c40c5b4 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -64,6 +64,15 @@ class _TestNode(Node): ) self.data = dict(data) + node_type_value = data.get("type") + if isinstance(node_type_value, NodeType): + self.node_type = node_type_value + elif isinstance(node_type_value, str): + try: + self.node_type = NodeType(node_type_value) + except ValueError: + pass + def _run(self): raise NotImplementedError @@ -179,3 +188,22 @@ def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( graph = Graph.init(graph_config=graph_config, node_factory=node_factory) assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH + + +def test_graph_validation_blocks_start_and_trigger_coexistence( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +) -> None: + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + { + "id": "trigger", + "data": {"type": NodeType.TRIGGER_WEBHOOK, "title": "Webhook", "execution_type": NodeExecutionType.ROOT}, + }, + ] + graph_config["edges"] = [] + + with pytest.raises(GraphValidationError) as exc_info: + Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert any(issue.code == "TRIGGER_START_NODE_CONFLICT" for issue in exc_info.value.issues) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 9700cbaf0e..2081dbf865 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -1,7 +1,8 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +from core.workflow.graph.validation import GraphValidationError, GraphValidationIssue from models.model import App from models.workflow import Workflow from services.workflow_service import WorkflowService @@ -161,3 +162,25 @@ class TestWorkflowService: assert workflows == [] assert has_more is False mock_session.scalars.assert_called_once() + + def test_validate_graph_structure_invokes_graph_init(self, workflow_service, mock_app): + graph = {"nodes": [], "edges": []} + + with patch("services.workflow_service.Graph.init") as mock_graph_init: + workflow_service.validate_graph_structure(mock_app, graph) + + mock_graph_init.assert_called_once() + assert mock_graph_init.call_args.kwargs["graph_config"] is graph + assert "node_factory" in mock_graph_init.call_args.kwargs + + def test_validate_graph_structure_propagates_graph_errors(self, workflow_service, mock_app): + graph = {"nodes": [], "edges": []} + issue = GraphValidationIssue(code="ERR", message="invalid") + + with patch("services.workflow_service.Graph.init", side_effect=GraphValidationError([issue])): + with pytest.raises(GraphValidationError): + workflow_service.validate_graph_structure(mock_app, graph) + + def test_validate_graph_structure_requires_nodes_and_edges(self, workflow_service, mock_app): + with pytest.raises(ValueError, match="must include 'nodes' and 'edges'"): + workflow_service.validate_graph_structure(mock_app, {"nodes": []})