mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
WIP: feat(api): do not return paused node_execution records & preserve node_execution_id across pause
This commit is contained in:
parent
77dc8a6edb
commit
1ad2b97169
@ -488,6 +488,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.triggered_from == triggered_from,
|
||||
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
|
||||
@ -260,10 +260,33 @@ class Node(Generic[NodeDataT]):
|
||||
return self._node_execution_id
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
if not self._node_execution_id:
|
||||
self._node_execution_id = str(uuid4())
|
||||
if self._node_execution_id:
|
||||
return self._node_execution_id
|
||||
|
||||
resumed_execution_id = self._restore_execution_id_from_runtime_state()
|
||||
if resumed_execution_id:
|
||||
self._node_execution_id = resumed_execution_id
|
||||
return self._node_execution_id
|
||||
|
||||
self._node_execution_id = str(uuid4())
|
||||
return self._node_execution_id
|
||||
|
||||
def _restore_execution_id_from_runtime_state(self) -> str | None:
|
||||
graph_execution = self.graph_runtime_state.graph_execution
|
||||
try:
|
||||
node_executions = graph_execution.node_executions
|
||||
except AttributeError:
|
||||
return None
|
||||
if not isinstance(node_executions, dict):
|
||||
return None
|
||||
node_execution = node_executions.get(self._node_id)
|
||||
if node_execution is None:
|
||||
return None
|
||||
execution_id = node_execution.execution_id
|
||||
if not execution_id:
|
||||
return None
|
||||
return str(execution_id)
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import HumanInputFormFilledEvent, NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.node_events.base import NodeEventBase
|
||||
from core.workflow.node_events.node import StreamCompletedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
@ -166,34 +167,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
resolved_placeholder_values=resolved_placeholder_values,
|
||||
)
|
||||
|
||||
def _create_form(self) -> Generator[NodeEventBase, None, None] | NodeRunResult:
|
||||
try:
|
||||
params = FormCreateParams(
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
rendered_content=self._render_form_content(),
|
||||
resolved_placeholder_values=self._resolve_inputs(),
|
||||
)
|
||||
form_entity = self._form_repository.create_form(params)
|
||||
# Create human input required event
|
||||
|
||||
logger.info(
|
||||
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
|
||||
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
||||
self.id,
|
||||
form_entity.id,
|
||||
)
|
||||
yield self._form_to_pause_event(form_entity)
|
||||
except Exception as e:
|
||||
logger.exception("Human Input node failed to execute, node_id=%s", self.id)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="HumanInputNodeError",
|
||||
)
|
||||
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Execute the human input node.
|
||||
|
||||
@ -208,56 +182,69 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
repo = self._form_repository
|
||||
form = repo.get_form(self._workflow_execution_id, self.id)
|
||||
if form is None:
|
||||
return self._create_form()
|
||||
|
||||
if form.submitted:
|
||||
selected_action_id = form.selected_action_id
|
||||
if selected_action_id is None:
|
||||
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
|
||||
submitted_data = form.submitted_data or {}
|
||||
outputs: dict[str, Any] = dict(submitted_data)
|
||||
outputs["__action_id"] = selected_action_id
|
||||
rendered_content = self._render_form_content_with_outputs(
|
||||
form.rendered_content,
|
||||
outputs,
|
||||
self._node_data.outputs_field_names(),
|
||||
params = FormCreateParams(
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
rendered_content=self._render_form_content_before_submission(),
|
||||
resolved_placeholder_values=self._resolve_inputs(),
|
||||
)
|
||||
outputs["__rendered_content"] = rendered_content
|
||||
form_entity = self._form_repository.create_form(params)
|
||||
# Create human input required event
|
||||
|
||||
action_text = self._node_data.find_action_text(selected_action_id)
|
||||
|
||||
yield HumanInputFormFilledEvent(
|
||||
rendered_content=rendered_content,
|
||||
action_id=selected_action_id,
|
||||
action_text=action_text,
|
||||
logger.info(
|
||||
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
|
||||
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
||||
self.id,
|
||||
form_entity.id,
|
||||
)
|
||||
yield self._form_to_pause_event(form_entity)
|
||||
return
|
||||
|
||||
return NodeRunResult(
|
||||
if form.status == HumanInputFormStatus.TIMEOUT or form.expiration_time <= naive_utc_now():
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={},
|
||||
edge_source_handle="__timeout",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if not form.submitted:
|
||||
yield self._form_to_pause_event(form)
|
||||
return
|
||||
|
||||
selected_action_id = form.selected_action_id
|
||||
if selected_action_id is None:
|
||||
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
|
||||
submitted_data = form.submitted_data or {}
|
||||
outputs: dict[str, Any] = dict(submitted_data)
|
||||
outputs["__action_id"] = selected_action_id
|
||||
rendered_content = self._render_form_content_with_outputs(
|
||||
form.rendered_content,
|
||||
outputs,
|
||||
self._node_data.outputs_field_names(),
|
||||
)
|
||||
outputs["__rendered_content"] = rendered_content
|
||||
|
||||
action_text = self._node_data.find_action_text(selected_action_id)
|
||||
|
||||
yield HumanInputFormFilledEvent(
|
||||
rendered_content=rendered_content,
|
||||
action_id=selected_action_id,
|
||||
action_text=action_text,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
edge_source_handle=selected_action_id,
|
||||
)
|
||||
)
|
||||
|
||||
if form.status == HumanInputFormStatus.TIMEOUT or form.expiration_time <= naive_utc_now():
|
||||
outputs: dict[str, Any] = {
|
||||
"__rendered_content": self._render_form_content_with_outputs(
|
||||
form.rendered_content,
|
||||
{},
|
||||
self._node_data.outputs_field_names(),
|
||||
)
|
||||
}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
edge_source_handle="__timeout",
|
||||
)
|
||||
|
||||
return self._pause_with_form(form)
|
||||
|
||||
def _pause_with_form(self, form_entity: HumanInputFormEntity) -> Generator[NodeEventBase, None, None]:
|
||||
yield self._form_to_pause_event(form_entity)
|
||||
|
||||
def _render_form_content(self) -> str:
|
||||
def _render_form_content_before_submission(self) -> str:
|
||||
"""
|
||||
Process form content by substituting variables.
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from typing import Any
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
@ -199,8 +200,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
if deduplicated_results:
|
||||
return _dict_to_workflow_node_execution_model(deduplicated_results[0])
|
||||
for row in deduplicated_results:
|
||||
model = _dict_to_workflow_node_execution_model(row)
|
||||
if model.status != WorkflowNodeExecutionStatus.PAUSED:
|
||||
return model
|
||||
|
||||
return None
|
||||
|
||||
@ -293,6 +296,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
||||
if model and model.id: # Ensure model is valid
|
||||
models.append(model)
|
||||
|
||||
models = [model for model in models if model.status != WorkflowNodeExecutionStatus.PAUSED]
|
||||
|
||||
# Sort by index DESC for trace visualization
|
||||
models.sort(key=lambda x: x.index, reverse=True)
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from sqlalchemy import asc, delete, desc, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
|
||||
@ -76,6 +77,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
WorkflowNodeExecutionModel.workflow_id == workflow_id,
|
||||
WorkflowNodeExecutionModel.node_id == node_id,
|
||||
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
|
||||
)
|
||||
.order_by(desc(WorkflowNodeExecutionModel.created_at))
|
||||
.limit(1)
|
||||
@ -109,6 +111,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
|
||||
).order_by(asc(WorkflowNodeExecutionModel.created_at))
|
||||
|
||||
with self._session_maker() as session:
|
||||
|
||||
@ -0,0 +1,336 @@
|
||||
import time
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.model import App, AppMode, IconType
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
|
||||
|
||||
|
||||
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
form_entity.submitted = False
|
||||
repo.create_form.return_value = form_entity
|
||||
repo.get_form.return_value = None
|
||||
return repo
|
||||
|
||||
|
||||
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
form_entity.submitted = True
|
||||
form_entity.selected_action_id = action_id
|
||||
form_entity.submitted_data = {}
|
||||
form_entity.status = HumanInputFormStatus.WAITING
|
||||
form_entity.expiration_time = naive_utc_now() + timedelta(hours=1)
|
||||
repo.get_form.return_value = form_entity
|
||||
return repo
|
||||
|
||||
|
||||
def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
user_id=user_id,
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_graph(
|
||||
runtime_state: GraphRuntimeState,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
form_repository: HumanInputFormRepository,
|
||||
) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": start_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="human",
|
||||
form_content="Awaiting human input",
|
||||
inputs=[],
|
||||
user_actions=[
|
||||
UserAction(id="continue", title="Continue"),
|
||||
],
|
||||
)
|
||||
human_node = HumanInputNode(
|
||||
id="human",
|
||||
config={"id": "human", "data": human_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="end",
|
||||
outputs=[],
|
||||
desc=None,
|
||||
)
|
||||
end_node = EndNode(
|
||||
id="end",
|
||||
config={"id": "end", "data": end_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
return (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(human_node)
|
||||
.add_node(end_node, from_node_id="human", source_handle="continue")
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def _build_generate_entity(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
workflow_execution_id: str,
|
||||
user_id: str,
|
||||
) -> WorkflowAppGenerateEntity:
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
return WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user_id,
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
)
|
||||
|
||||
|
||||
class TestHumanInputResumeNodeExecutionIntegration:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_data(self, db_session_with_containers: Session):
|
||||
tenant = Tenant(
|
||||
name="Test Tenant",
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account = Account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App",
|
||||
description="",
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
icon_type=IconType.EMOJI.value,
|
||||
icon="rocket",
|
||||
icon_background="#4ECDC4",
|
||||
enable_site=False,
|
||||
enable_api=False,
|
||||
api_rpm=0,
|
||||
api_rph=0,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
is_universal=False,
|
||||
max_active_requests=None,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow = Workflow(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=account.id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
self.session = db_session_with_containers
|
||||
self.tenant = tenant
|
||||
self.account = account
|
||||
self.app = app
|
||||
self.workflow = workflow
|
||||
|
||||
yield
|
||||
|
||||
self.session.execute(delete(WorkflowNodeExecutionModel))
|
||||
self.session.execute(delete(WorkflowRun))
|
||||
self.session.execute(delete(Workflow).where(Workflow.id == self.workflow.id))
|
||||
self.session.execute(delete(App).where(App.id == self.app.id))
|
||||
self.session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == self.tenant.id))
|
||||
self.session.execute(delete(Account).where(Account.id == self.account.id))
|
||||
self.session.execute(delete(Tenant).where(Tenant.id == self.tenant.id))
|
||||
self.session.commit()
|
||||
|
||||
def _build_persistence_layer(self, execution_id: str) -> WorkflowPersistenceLayer:
|
||||
generate_entity = _build_generate_entity(
|
||||
tenant_id=self.tenant.id,
|
||||
app_id=self.app.id,
|
||||
workflow_id=self.workflow.id,
|
||||
workflow_execution_id=execution_id,
|
||||
user_id=self.account.id,
|
||||
)
|
||||
execution_repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=self.session.get_bind(),
|
||||
user=self.account,
|
||||
app_id=self.app.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
node_execution_repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=self.session.get_bind(),
|
||||
user=self.account,
|
||||
app_id=self.app.id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
return WorkflowPersistenceLayer(
|
||||
application_generate_entity=generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id=self.workflow.id,
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
version=self.workflow.version,
|
||||
graph_data=self.workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=execution_repo,
|
||||
workflow_node_execution_repository=node_execution_repo,
|
||||
)
|
||||
|
||||
def _run_graph(self, graph: Graph, runtime_state: GraphRuntimeState, execution_id: str) -> None:
|
||||
engine = GraphEngine(
|
||||
workflow_id=self.workflow.id,
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
engine.layer(self._build_persistence_layer(execution_id))
|
||||
for _ in engine.run():
|
||||
continue
|
||||
|
||||
def test_resume_human_input_does_not_create_duplicate_node_execution(self):
|
||||
execution_id = str(uuid.uuid4())
|
||||
runtime_state = _build_runtime_state(
|
||||
workflow_execution_id=execution_id,
|
||||
app_id=self.app.id,
|
||||
workflow_id=self.workflow.id,
|
||||
user_id=self.account.id,
|
||||
)
|
||||
pause_repo = _mock_form_repository_without_submission()
|
||||
paused_graph = _build_graph(
|
||||
runtime_state,
|
||||
self.tenant.id,
|
||||
self.app.id,
|
||||
self.workflow.id,
|
||||
self.account.id,
|
||||
pause_repo,
|
||||
)
|
||||
self._run_graph(paused_graph, runtime_state, execution_id)
|
||||
|
||||
snapshot = runtime_state.dumps()
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resume_repo = _mock_form_repository_with_submission(action_id="continue")
|
||||
resumed_graph = _build_graph(
|
||||
resumed_state,
|
||||
self.tenant.id,
|
||||
self.app.id,
|
||||
self.workflow.id,
|
||||
self.account.id,
|
||||
resume_repo,
|
||||
)
|
||||
self._run_graph(resumed_graph, resumed_state, execution_id)
|
||||
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == execution_id,
|
||||
WorkflowNodeExecutionModel.node_id == "human",
|
||||
)
|
||||
records = self.session.execute(stmt).scalars().all()
|
||||
assert len(records) == 1
|
||||
assert records[0].status != "paused"
|
||||
assert records[0].triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
assert records[0].created_by_role == CreatorUserRole.ACCOUNT
|
||||
@ -465,6 +465,27 @@ class TestWorkflowRunService:
|
||||
db.session.add(node_execution)
|
||||
node_executions.append(node_execution)
|
||||
|
||||
paused_node_execution = WorkflowNodeExecutionModel(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from="workflow-run",
|
||||
workflow_run_id=workflow_run.id,
|
||||
index=99,
|
||||
node_id="node_paused",
|
||||
node_type="human_input",
|
||||
title="Paused Node",
|
||||
inputs=json.dumps({"input": "paused"}),
|
||||
process_data=json.dumps({"process": "paused"}),
|
||||
status="paused",
|
||||
elapsed_time=0.5,
|
||||
execution_metadata=json.dumps({"tokens": 0}),
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(paused_node_execution)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
@ -477,6 +498,7 @@ class TestWorkflowRunService:
|
||||
|
||||
# Verify node execution properties
|
||||
for node_execution in result:
|
||||
assert node_execution.status != "paused"
|
||||
assert node_execution.tenant_id == app.tenant_id
|
||||
assert node_execution.app_id == app.id
|
||||
assert node_execution.workflow_run_id == workflow_run.id
|
||||
|
||||
@ -9,7 +9,7 @@ import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.node_events import PauseRequestedEvent
|
||||
from core.workflow.node_events.node import StreamCompletedEvent
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
|
||||
@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||
@ -52,6 +53,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
call_args = mock_session.scalar.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
compiled = call_args.compile()
|
||||
assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values()
|
||||
|
||||
def test_get_node_last_execution_not_found(self, repository):
|
||||
"""Test getting the last execution for a node when it doesn't exist."""
|
||||
# Arrange
|
||||
@ -93,6 +97,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
call_args = mock_session.execute.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
compiled = call_args.compile()
|
||||
assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values()
|
||||
|
||||
def test_get_executions_by_workflow_run_empty(self, repository):
|
||||
"""Test getting executions for a workflow run when none exist."""
|
||||
# Arrange
|
||||
|
||||
Loading…
Reference in New Issue
Block a user