WIP: feat(api): do not return paused node_execution records & preserve node_execution_id across pause

This commit is contained in:
QuantumGhost 2026-01-04 23:38:40 +08:00
parent 77dc8a6edb
commit 1ad2b97169
9 changed files with 458 additions and 74 deletions

View File

@ -488,6 +488,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.triggered_from == triggered_from,
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
)
if self._app_id:

View File

@ -260,10 +260,33 @@ class Node(Generic[NodeDataT]):
return self._node_execution_id
def ensure_execution_id(self) -> str:
    """Return this node's execution id, assigning one first if it has none.

    Resolution order:

    1. the id already stored on the node;
    2. the id returned by ``_restore_execution_id_from_runtime_state()``,
       so a node that is resumed keeps the id it had before;
    3. a freshly generated UUID4 string.

    Returns:
        The (possibly newly assigned) execution id.
    """
    if self._node_execution_id:
        return self._node_execution_id

    # Consult the runtime state *before* minting a new id: generating a UUID
    # first would always populate the field and make the restore path
    # unreachable.
    self._node_execution_id = self._restore_execution_id_from_runtime_state() or str(uuid4())
    return self._node_execution_id
def _restore_execution_id_from_runtime_state(self) -> str | None:
graph_execution = self.graph_runtime_state.graph_execution
try:
node_executions = graph_execution.node_executions
except AttributeError:
return None
if not isinstance(node_executions, dict):
return None
node_execution = node_executions.get(self._node_id)
if node_execution is None:
return None
execution_id = node_execution.execution_id
if not execution_id:
return None
return str(execution_id)
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
    """Validate *data* with this node's data type and return the result."""
    validated = self._node_data_type.model_validate(data)
    # cast() only informs the type checker; the validated object is returned as-is.
    return cast(NodeDataT, validated)

View File

@ -8,6 +8,7 @@ from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import HumanInputFormFilledEvent, NodeRunResult, PauseRequestedEvent
from core.workflow.node_events.base import NodeEventBase
from core.workflow.node_events.node import StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
@ -166,34 +167,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
resolved_placeholder_values=resolved_placeholder_values,
)
def _create_form(self) -> Generator[NodeEventBase, None, None] | NodeRunResult:
    """Create the form for this node and yield the pause event built from it.

    Because the body contains ``yield``, calling this always produces a
    generator. If an ``Exception`` is raised inside the ``try`` block, it is
    logged and the generator finishes with a FAILED ``NodeRunResult`` as its
    return value (i.e. ``StopIteration.value``), not as a yielded item.
    """
    try:
        created_form = self._form_repository.create_form(
            FormCreateParams(
                workflow_execution_id=self._workflow_execution_id,
                node_id=self.id,
                form_config=self._node_data,
                rendered_content=self._render_form_content(),
                resolved_placeholder_values=self._resolve_inputs(),
            )
        )

        # Record which workflow run / node / form the pause belongs to.
        run_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
        logger.info(
            "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
            run_id,
            self.id,
            created_form.id,
        )

        # The yield stays inside the try so errors thrown in at this point are handled too.
        pause_event = self._form_to_pause_event(created_form)
        yield pause_event
    except Exception as exc:
        logger.exception("Human Input node failed to execute, node_id=%s", self.id)
        failure = NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            error=str(exc),
            error_type="HumanInputNodeError",
        )
        return failure
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Execute the human input node.
@ -208,56 +182,69 @@ class HumanInputNode(Node[HumanInputNodeData]):
repo = self._form_repository
form = repo.get_form(self._workflow_execution_id, self.id)
if form is None:
return self._create_form()
if form.submitted:
selected_action_id = form.selected_action_id
if selected_action_id is None:
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
submitted_data = form.submitted_data or {}
outputs: dict[str, Any] = dict(submitted_data)
outputs["__action_id"] = selected_action_id
rendered_content = self._render_form_content_with_outputs(
form.rendered_content,
outputs,
self._node_data.outputs_field_names(),
params = FormCreateParams(
workflow_execution_id=self._workflow_execution_id,
node_id=self.id,
form_config=self._node_data,
rendered_content=self._render_form_content_before_submission(),
resolved_placeholder_values=self._resolve_inputs(),
)
outputs["__rendered_content"] = rendered_content
form_entity = self._form_repository.create_form(params)
# Create human input required event
action_text = self._node_data.find_action_text(selected_action_id)
yield HumanInputFormFilledEvent(
rendered_content=rendered_content,
action_id=selected_action_id,
action_text=action_text,
logger.info(
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
self.id,
form_entity.id,
)
yield self._form_to_pause_event(form_entity)
return
return NodeRunResult(
if form.status == HumanInputFormStatus.TIMEOUT or form.expiration_time <= naive_utc_now():
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={},
edge_source_handle="__timeout",
)
)
return
if not form.submitted:
yield self._form_to_pause_event(form)
return
selected_action_id = form.selected_action_id
if selected_action_id is None:
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
submitted_data = form.submitted_data or {}
outputs: dict[str, Any] = dict(submitted_data)
outputs["__action_id"] = selected_action_id
rendered_content = self._render_form_content_with_outputs(
form.rendered_content,
outputs,
self._node_data.outputs_field_names(),
)
outputs["__rendered_content"] = rendered_content
action_text = self._node_data.find_action_text(selected_action_id)
yield HumanInputFormFilledEvent(
rendered_content=rendered_content,
action_id=selected_action_id,
action_text=action_text,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
edge_source_handle=selected_action_id,
)
)
if form.status == HumanInputFormStatus.TIMEOUT or form.expiration_time <= naive_utc_now():
outputs: dict[str, Any] = {
"__rendered_content": self._render_form_content_with_outputs(
form.rendered_content,
{},
self._node_data.outputs_field_names(),
)
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
edge_source_handle="__timeout",
)
return self._pause_with_form(form)
def _pause_with_form(self, form_entity: HumanInputFormEntity) -> Generator[NodeEventBase, None, None]:
    """Yield the single pause event built from *form_entity*."""
    pause_event = self._form_to_pause_event(form_entity)
    yield pause_event
def _render_form_content(self) -> str:
def _render_form_content_before_submission(self) -> str:
"""
Process form content by substituting variables.

View File

@ -13,6 +13,7 @@ from typing import Any
from sqlalchemy.orm import sessionmaker
from core.workflow.enums import WorkflowNodeExecutionStatus
from extensions.logstore.aliyun_logstore import AliyunLogStore
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@ -199,8 +200,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
reverse=True,
)
if deduplicated_results:
return _dict_to_workflow_node_execution_model(deduplicated_results[0])
for row in deduplicated_results:
model = _dict_to_workflow_node_execution_model(row)
if model.status != WorkflowNodeExecutionStatus.PAUSED:
return model
return None
@ -293,6 +296,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
if model and model.id: # Ensure model is valid
models.append(model)
models = [model for model in models if model.status != WorkflowNodeExecutionStatus.PAUSED]
# Sort by index DESC for trace visualization
models.sort(key=lambda x: x.index, reverse=True)

View File

@ -13,6 +13,7 @@ from sqlalchemy import asc, delete, desc, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.enums import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@ -76,6 +77,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_id == workflow_id,
WorkflowNodeExecutionModel.node_id == node_id,
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
)
.order_by(desc(WorkflowNodeExecutionModel.created_at))
.limit(1)
@ -109,6 +111,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
).order_by(asc(WorkflowNodeExecutionModel.created_at))
with self._session_maker() as session:

View File

@ -0,0 +1,336 @@
import time
import uuid
from datetime import timedelta
from unittest.mock import MagicMock
import pytest
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from models import Account
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import App, AppMode, IconType
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
    """Build a mock repository with no existing form.

    ``get_form`` returns ``None`` and ``create_form`` returns an unsubmitted
    mock form entity.
    """
    pending_form = MagicMock(spec=HumanInputFormEntity)
    pending_form.configure_mock(
        id="test-form-id",
        web_app_token="test-form-token",
        recipients=[],
        rendered_content="rendered",
        submitted=False,
    )

    mock_repo = MagicMock(spec=HumanInputFormRepository)
    mock_repo.create_form.return_value = pending_form
    mock_repo.get_form.return_value = None
    return mock_repo
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
    """Build a mock repository whose ``get_form`` returns a submitted form.

    The form is marked submitted with *action_id* as the selected action,
    empty submitted data, WAITING status and an expiration one hour ahead.
    """
    submitted_form = MagicMock(spec=HumanInputFormEntity)
    submitted_form.configure_mock(
        id="test-form-id",
        web_app_token="test-form-token",
        recipients=[],
        rendered_content="rendered",
        submitted=True,
        selected_action_id=action_id,
        submitted_data={},
        status=HumanInputFormStatus.WAITING,
        expiration_time=naive_utc_now() + timedelta(hours=1),
    )

    mock_repo = MagicMock(spec=HumanInputFormRepository)
    mock_repo.get_form.return_value = submitted_form
    return mock_repo
def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState:
    """Create a ``GraphRuntimeState`` with an otherwise empty variable pool.

    Only the system variables (execution, app, workflow and user ids) are
    populated; user inputs and conversation variables are empty.
    """
    system_variables = SystemVariable(
        workflow_execution_id=workflow_execution_id,
        app_id=app_id,
        workflow_id=workflow_id,
        user_id=user_id,
    )
    pool = VariablePool(
        system_variables=system_variables,
        user_inputs={},
        conversation_variables=[],
    )
    started_at = time.perf_counter()
    return GraphRuntimeState(variable_pool=pool, start_at=started_at)
def _build_graph(
    runtime_state: GraphRuntimeState,
    tenant_id: str,
    app_id: str,
    workflow_id: str,
    user_id: str,
    form_repository: HumanInputFormRepository,
) -> Graph:
    """Assemble a start -> human -> end graph bound to *runtime_state*.

    The human-input node uses *form_repository*, and the end node is attached
    to the human node's ``"continue"`` source handle.
    """

    def _config(node_id: str, node_data) -> dict[str, object]:
        # Every node here is configured as {"id": ..., "data": <dumped node data>}.
        return {"id": node_id, "data": node_data.model_dump()}

    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    init_params = GraphInitParams(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        graph_config=graph_config,
        user_id=user_id,
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    # Keyword arguments shared by all three node constructors.
    shared = {"graph_init_params": init_params, "graph_runtime_state": runtime_state}

    start_node = StartNode(
        id="start",
        config=_config("start", StartNodeData(title="start", variables=[])),
        **shared,
    )

    human_node = HumanInputNode(
        id="human",
        config=_config(
            "human",
            HumanInputNodeData(
                title="human",
                form_content="Awaiting human input",
                inputs=[],
                user_actions=[UserAction(id="continue", title="Continue")],
            ),
        ),
        form_repository=form_repository,
        **shared,
    )

    end_node = EndNode(
        id="end",
        config=_config("end", EndNodeData(title="end", outputs=[], desc=None)),
        **shared,
    )

    builder = Graph.new()
    builder = builder.add_root(start_node)
    builder = builder.add_node(human_node)
    builder = builder.add_node(end_node, from_node_id="human", source_handle="continue")
    return builder.build()
def _build_generate_entity(
    tenant_id: str,
    app_id: str,
    workflow_id: str,
    workflow_execution_id: str,
    user_id: str,
) -> WorkflowAppGenerateEntity:
    """Create a non-streaming, debugger-invoked generate entity.

    It carries a random task id, empty inputs and files, and a workflow-mode
    app config built from the given ids.
    """
    config = WorkflowUIBasedAppConfig(
        tenant_id=tenant_id,
        app_id=app_id,
        app_mode=AppMode.WORKFLOW,
        workflow_id=workflow_id,
    )
    task_id = str(uuid.uuid4())
    return WorkflowAppGenerateEntity(
        task_id=task_id,
        app_config=config,
        inputs={},
        files=[],
        user_id=user_id,
        stream=False,
        invoke_from=InvokeFrom.DEBUGGER,
        workflow_execution_id=workflow_execution_id,
    )
class TestHumanInputResumeNodeExecutionIntegration:
    """Integration test for resuming a workflow paused on a human-input node.

    Runs the graph once so the human node pauses, restores the runtime state
    from a snapshot, runs it again with a submitted form, and checks that only
    one node-execution row exists for the human node.
    """

    @pytest.fixture(autouse=True)
    def setup_test_data(self, db_session_with_containers: Session):
        """Create tenant, account, app and workflow rows; remove them afterwards."""
        tenant = Tenant(
            name="Test Tenant",
            status="normal",
        )
        db_session_with_containers.add(tenant)
        db_session_with_containers.commit()

        account = Account(
            email="test@example.com",
            name="Test User",
            interface_language="en-US",
            status="active",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.commit()

        tenant_join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db_session_with_containers.add(tenant_join)
        db_session_with_containers.commit()

        account.current_tenant = tenant

        app = App(
            tenant_id=tenant.id,
            name="Test App",
            description="",
            mode=AppMode.WORKFLOW.value,
            icon_type=IconType.EMOJI.value,
            icon="rocket",
            icon_background="#4ECDC4",
            enable_site=False,
            enable_api=False,
            api_rpm=0,
            api_rph=0,
            is_demo=False,
            is_public=False,
            is_universal=False,
            max_active_requests=None,
            created_by=account.id,
            updated_by=account.id,
        )
        db_session_with_containers.add(app)
        db_session_with_containers.commit()

        workflow = Workflow(
            tenant_id=tenant.id,
            app_id=app.id,
            type="workflow",
            version="draft",
            graph='{"nodes": [], "edges": []}',
            features='{"file_upload": {"enabled": false}}',
            created_by=account.id,
            created_at=naive_utc_now(),
        )
        db_session_with_containers.add(workflow)
        db_session_with_containers.commit()

        self.session = db_session_with_containers
        self.tenant = tenant
        self.account = account
        self.app = app
        self.workflow = workflow

        yield

        # Scope every delete to this fixture's tenant so teardown never wipes
        # rows that belong to other tests sharing the same database.
        self.session.execute(
            delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self.tenant.id)
        )
        self.session.execute(delete(WorkflowRun).where(WorkflowRun.tenant_id == self.tenant.id))
        self.session.execute(delete(Workflow).where(Workflow.id == self.workflow.id))
        self.session.execute(delete(App).where(App.id == self.app.id))
        self.session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == self.tenant.id))
        self.session.execute(delete(Account).where(Account.id == self.account.id))
        self.session.execute(delete(Tenant).where(Tenant.id == self.tenant.id))
        self.session.commit()

    def _build_persistence_layer(self, execution_id: str) -> WorkflowPersistenceLayer:
        """Build a persistence layer backed by SQLAlchemy repositories on the test session's bind."""
        generate_entity = _build_generate_entity(
            tenant_id=self.tenant.id,
            app_id=self.app.id,
            workflow_id=self.workflow.id,
            workflow_execution_id=execution_id,
            user_id=self.account.id,
        )
        execution_repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=self.session.get_bind(),
            user=self.account,
            app_id=self.app.id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )
        node_execution_repo = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=self.session.get_bind(),
            user=self.account,
            app_id=self.app.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
        return WorkflowPersistenceLayer(
            application_generate_entity=generate_entity,
            workflow_info=PersistenceWorkflowInfo(
                workflow_id=self.workflow.id,
                workflow_type=WorkflowType.WORKFLOW,
                version=self.workflow.version,
                graph_data=self.workflow.graph_dict,
            ),
            workflow_execution_repository=execution_repo,
            workflow_node_execution_repository=node_execution_repo,
        )

    def _run_graph(self, graph: Graph, runtime_state: GraphRuntimeState, execution_id: str) -> None:
        """Run *graph* with the persistence layer attached, discarding emitted events."""
        engine = GraphEngine(
            workflow_id=self.workflow.id,
            graph=graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        engine.layer(self._build_persistence_layer(execution_id))
        # Only the persisted side effects matter; drain the event stream.
        for _ in engine.run():
            continue

    def test_resume_human_input_does_not_create_duplicate_node_execution(self):
        """A pause followed by a resume must leave exactly one row for the human node."""
        execution_id = str(uuid.uuid4())
        runtime_state = _build_runtime_state(
            workflow_execution_id=execution_id,
            app_id=self.app.id,
            workflow_id=self.workflow.id,
            user_id=self.account.id,
        )

        # First run: no form exists yet, so one is created.
        pause_repo = _mock_form_repository_without_submission()
        paused_graph = _build_graph(
            runtime_state,
            self.tenant.id,
            self.app.id,
            self.workflow.id,
            self.account.id,
            pause_repo,
        )
        self._run_graph(paused_graph, runtime_state, execution_id)

        # Second run: rebuild the state from a snapshot and supply a submitted form.
        snapshot = runtime_state.dumps()
        resumed_state = GraphRuntimeState.from_snapshot(snapshot)
        resume_repo = _mock_form_repository_with_submission(action_id="continue")
        resumed_graph = _build_graph(
            resumed_state,
            self.tenant.id,
            self.app.id,
            self.workflow.id,
            self.account.id,
            resume_repo,
        )
        self._run_graph(resumed_graph, resumed_state, execution_id)

        stmt = select(WorkflowNodeExecutionModel).where(
            WorkflowNodeExecutionModel.workflow_run_id == execution_id,
            WorkflowNodeExecutionModel.node_id == "human",
        )
        records = self.session.execute(stmt).scalars().all()
        assert len(records) == 1
        assert records[0].status != "paused"
        assert records[0].triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
        assert records[0].created_by_role == CreatorUserRole.ACCOUNT

View File

@ -465,6 +465,27 @@ class TestWorkflowRunService:
db.session.add(node_execution)
node_executions.append(node_execution)
paused_node_execution = WorkflowNodeExecutionModel(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow_run.workflow_id,
triggered_from="workflow-run",
workflow_run_id=workflow_run.id,
index=99,
node_id="node_paused",
node_type="human_input",
title="Paused Node",
inputs=json.dumps({"input": "paused"}),
process_data=json.dumps({"process": "paused"}),
status="paused",
elapsed_time=0.5,
execution_metadata=json.dumps({"tokens": 0}),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(paused_node_execution)
db.session.commit()
# Act: Execute the method under test
@ -477,6 +498,7 @@ class TestWorkflowRunService:
# Verify node execution properties
for node_execution in result:
assert node_execution.status != "paused"
assert node_execution.tenant_id == app.tenant_id
assert node_execution.app_id == app.id
assert node_execution.workflow_run_id == workflow_run.id

View File

@ -9,7 +9,7 @@ import pytest
from pydantic import ValidationError
from core.workflow.entities import GraphInitParams
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.node_events import PauseRequestedEvent
from core.workflow.node_events.node import StreamCompletedEvent
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,

View File

@ -5,6 +5,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.workflow.enums import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
@ -52,6 +53,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
call_args = mock_session.scalar.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
compiled = call_args.compile()
assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values()
def test_get_node_last_execution_not_found(self, repository):
"""Test getting the last execution for a node when it doesn't exist."""
# Arrange
@ -93,6 +97,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
call_args = mock_session.execute.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
compiled = call_args.compile()
assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values()
def test_get_executions_by_workflow_run_empty(self, repository):
"""Test getting executions for a workflow run when none exist."""
# Arrange