WIP: human input timeout

This commit is contained in:
QuantumGhost 2025-12-25 17:41:53 +08:00
parent 203a3a68af
commit 5d0dd329f2
13 changed files with 241 additions and 18 deletions

View File

@ -82,6 +82,12 @@ class AppExecutionConfig(BaseSettings):
default=0,
)
HITL_GLOBAL_TIMEOUT_HOURS: PositiveInt = Field(
description="Maximum hours a workflow run can stay paused waiting for human input before global timeout.",
default=24 * 7,
ge=1,
)
class CodeExecutionSandboxConfig(BaseSettings):
"""

View File

@ -28,6 +28,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.workflow.entities import GraphInitParams
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
@ -55,7 +56,6 @@ from core.workflow.graph_events import (
NodeRunSucceededEvent,
)
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
@ -63,9 +63,9 @@ from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
from models.enums import UserFrom
from models.workflow import Workflow
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
logger = logging.getLogger(__name__)

View File

@ -13,6 +13,7 @@ from core.workflow.nodes.human_input.entities import (
EmailRecipient,
ExternalRecipient,
FormDefinition,
HumanInputFormStatus,
HumanInputNodeData,
MemberRecipient,
WebAppDeliveryMethod,
@ -106,6 +107,14 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
def submitted(self) -> bool:
return self._form_model.submitted_at is not None
@property
def status(self) -> HumanInputFormStatus:
    """Expose the status value recorded on the wrapped form model."""
    model = self._form_model
    return model.status
@property
def expiration_time(self) -> datetime:
    """Expose the expiration timestamp recorded on the wrapped form model."""
    model = self._form_model
    return model.expiration_time
@dataclasses.dataclass(frozen=True)
class HumanInputFormRecord:
@ -116,6 +125,7 @@ class HumanInputFormRecord:
definition: FormDefinition
rendered_content: str
expiration_time: datetime
status: HumanInputFormStatus
selected_action_id: str | None
submitted_data: Mapping[str, Any] | None
submitted_at: datetime | None
@ -142,6 +152,7 @@ class HumanInputFormRecord:
definition=FormDefinition.model_validate_json(form_model.form_definition),
rendered_content=form_model.rendered_content,
expiration_time=form_model.expiration_time,
status=form_model.status,
selected_action_id=form_model.selected_action_id,
submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None,
submitted_at=form_model.submitted_at,
@ -296,6 +307,8 @@ class HumanInputFormRepositoryImpl:
with self._session_factory(expire_on_commit=False) as session, session.begin():
# Generate unique form ID
form_id = str(uuidv7())
start_time = naive_utc_now()
node_expiration = form_config.expiration_time(start_time)
form_definition = FormDefinition(
form_content=form_config.form_content,
inputs=form_config.inputs,
@ -312,7 +325,8 @@ class HumanInputFormRepositoryImpl:
node_id=params.node_id,
form_definition=form_definition.model_dump_json(),
rendered_content=params.rendered_content,
expiration_time=form_config.expiration_time(naive_utc_now()),
expiration_time=node_expiration,
created_at=start_time,
)
session.add(form_model)
recipient_models: list[HumanInputFormRecipient] = []
@ -404,6 +418,7 @@ class HumanInputFormSubmissionRepository:
form_model.selected_action_id = selected_action_id
form_model.submitted_data = json.dumps(form_data)
form_model.submitted_at = naive_utc_now()
form_model.status = HumanInputFormStatus.SUBMITTED
form_model.submission_user_id = submission_user_id
form_model.submission_end_user_id = submission_end_user_id
form_model.completed_by_recipient_id = recipient_id
@ -415,3 +430,29 @@ class HumanInputFormSubmissionRepository:
session.refresh(recipient_model)
return HumanInputFormRecord.from_models(form_model, recipient_model)
def mark_timeout(self, *, form_id: str, reason: str | None = None) -> HumanInputFormRecord:
    """Flag a pending form as timed out and return its record.

    Args:
        form_id: Primary key of the form to update.
        reason: Accepted for the caller's benefit only; it is not persisted
            on the form by this method.

    Returns:
        A ``HumanInputFormRecord`` built from the form with no recipient.
        A form that is already in the TIMEOUT state is returned unchanged.

    Raises:
        FormNotFoundError: if no form with ``form_id`` exists, or if the form
            has already been submitted.
    """
    with self._session_factory(expire_on_commit=False) as session, session.begin():
        form = session.get(HumanInputForm, form_id)
        if form is None:
            raise FormNotFoundError(f"form not found, id={form_id}")

        # Repeated calls are harmless: a timed-out form is handed back as-is.
        if form.status == HumanInputFormStatus.TIMEOUT:
            return HumanInputFormRecord.from_models(form, None)

        already_submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED
        if already_submitted:
            raise FormNotFoundError(f"form already submitted, id={form_id}")

        form.status = HumanInputFormStatus.TIMEOUT
        # Wipe every submission-related column so the record carries no answer.
        for column in (
            "selected_action_id",
            "submitted_data",
            "submission_user_id",
            "submission_end_user_id",
            "completed_by_recipient_id",
        ):
            setattr(form, column, None)

        # The timeout reason is tracked downstream (status/error), not on the form.
        session.add(form)
        session.flush()
        session.refresh(form)
        return HumanInputFormRecord.from_models(form, None)

View File

@ -15,8 +15,9 @@ from core.workflow.repositories.human_input_form_repository import (
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from .entities import HumanInputNodeData, PlaceholderType
from .entities import HumanInputFormStatus, HumanInputNodeData, PlaceholderType
if TYPE_CHECKING:
from core.workflow.entities.graph_init_params import GraphInitParams
@ -221,6 +222,14 @@ class HumanInputNode(Node[HumanInputNodeData]):
edge_source_handle=selected_action_id,
)
if form.status == HumanInputFormStatus.TIMEOUT or form.expiration_time <= naive_utc_now():
outputs: dict[str, Any] = {"__rendered_content": form.rendered_content}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
edge_source_handle="__timeout",
)
return self._pause_with_form(form)
def _pause_with_form(self, form_entity: HumanInputFormEntity) -> Generator[NodeEventBase, None, None]:

View File

@ -1,9 +1,10 @@
import abc
import dataclasses
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Protocol
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.human_input.entities import HumanInputFormStatus, HumanInputNodeData
class HumanInputError(Exception):
@ -82,6 +83,18 @@ class HumanInputFormEntity(abc.ABC):
"""Whether the form has been submitted."""
...
@property
@abc.abstractmethod
def status(self) -> HumanInputFormStatus:
    """Return the form's current ``HumanInputFormStatus``.

    Abstract: concrete form entities must supply the value.
    """
    ...
@property
@abc.abstractmethod
def expiration_time(self) -> datetime:
    """Return the ``datetime`` at which the form expires.

    Abstract: concrete form entities must supply the value.
    """
    ...
class HumanInputFormRecipientEntity(abc.ABC):
@property

View File

@ -9,7 +9,8 @@ from core.repositories.human_input_reposotiry import (
HumanInputFormRecord,
HumanInputFormSubmissionRepository,
)
from core.workflow.nodes.human_input.entities import FormDefinition
from core.workflow.nodes.human_input.entities import FormDefinition, HumanInputFormStatus
from libs.datetime_utils import naive_utc_now
from libs.exception import BaseHTTPException
from models.account import Account
from models.human_input import RecipientType
@ -48,6 +49,14 @@ class Form:
def recipient_type(self) -> RecipientType | None:
return self._record.recipient_type
@property
def status(self) -> HumanInputFormStatus:
    """Expose the status held by the underlying form record."""
    record = self._record
    return record.status
@property
def expiration_time(self):
    """Expose the expiration timestamp held by the underlying form record."""
    record = self._record
    return record.expiration_time
class HumanInputError(Exception):
pass
@ -80,6 +89,14 @@ class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
pass
class FormExpiredError(HumanInputError, BaseHTTPException):
    """Error for a human input form that can no longer be submitted because it expired.

    Carries ``error_code`` "human_input_form_expired" and ``code`` 412.
    """

    error_code = "human_input_form_expired"
    code = 412

    def __init__(self, form_id: str):
        # The form id is embedded in the description passed to the base exception.
        message = f"This form has expired, form_id={form_id}"
        super().__init__(description=message)
logger = logging.getLogger(__name__)
@ -134,7 +151,7 @@ class HumanInputService:
if form is None:
raise WebAppDeliveryNotEnabledError()
self._ensure_not_submitted(form)
self._ensure_form_active(form)
self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data)
result = self._form_repository.mark_submitted(
@ -160,7 +177,7 @@ class HumanInputService:
if form is None or form.recipient_type != recipient_type:
raise WebAppDeliveryNotEnabledError()
self._ensure_not_submitted(form)
self._ensure_form_active(form)
self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data)
result = self._form_repository.mark_submitted(
@ -174,6 +191,15 @@ class HumanInputService:
self._enqueue_resume(result.workflow_run_id)
def _ensure_form_active(self, form: Form) -> None:
    """Reject forms that can no longer accept a submission.

    Raises:
        FormSubmittedError: if the form has already been submitted.
        FormExpiredError: if the form is marked TIMEOUT, or its expiration
            time is at or before the current naive-UTC time.
    """
    if form.submitted:
        raise FormSubmittedError(form.id)

    # The clock is only consulted when the form is not already flagged as timed out.
    timed_out = form.status == HumanInputFormStatus.TIMEOUT
    if timed_out or form.expiration_time <= naive_utc_now():
        raise FormExpiredError(form.id)
def _ensure_not_submitted(self, form: Form) -> None:
if form.submitted:
raise FormSubmittedError(form.id)

View File

@ -0,0 +1,110 @@
import logging
from datetime import timedelta
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories.human_input_reposotiry import HumanInputFormSubmissionRepository
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes.human_input.entities import FormDefinition, HumanInputFormStatus, TimeoutUnit
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
from models.human_input import HumanInputForm
from models.workflow import WorkflowPause, WorkflowRun
from services.human_input_service import HumanInputService
logger = logging.getLogger(__name__)
def _calculate_node_deadline(definition: FormDefinition, created_at, *, start_time=None):
    """Return the moment at which a form's node-level timeout elapses.

    The countdown starts at ``start_time`` when one is given, otherwise at
    ``created_at``. Its length is ``definition.timeout`` interpreted in
    ``definition.timeout_unit`` (hours or days).

    Raises:
        AssertionError: if the timeout unit is neither HOUR nor DAY.
    """
    origin = created_at if not start_time else start_time
    unit = definition.timeout_unit
    if unit == TimeoutUnit.HOUR:
        span = timedelta(hours=definition.timeout)
    elif unit == TimeoutUnit.DAY:
        span = timedelta(days=definition.timeout)
    else:
        raise AssertionError("unknown timeout unit.")
    return origin + span
def _is_global_timeout(form_model: HumanInputForm, global_timeout_hours: int) -> bool:
    """Tell whether an expired form hit the global limit rather than its node limit.

    Returns ``False`` outright when ``global_timeout_hours`` is not positive.
    Otherwise the result is ``True`` only when the global deadline
    (``created_at`` + ``global_timeout_hours``) is not later than the node
    deadline AND the stored ``expiration_time`` is not later than the global
    deadline.

    NOTE(review): if the stored ``expiration_time`` is always identical to the
    node deadline, both conditions can only hold when the two deadlines
    coincide exactly -- confirm this is the intended rule.
    """
    if global_timeout_hours <= 0:
        return False

    definition = FormDefinition.model_validate_json(form_model.form_definition)
    created = ensure_naive_utc(form_model.created_at)
    stored_expiration = ensure_naive_utc(form_model.expiration_time)
    node_limit = _calculate_node_deadline(definition, created)
    global_limit = created + timedelta(hours=global_timeout_hours)

    if global_limit > node_limit:
        # The node's own timeout fires first, so this is not a global timeout.
        return False
    return stored_expiration <= global_limit
def _handle_global_timeout(*, form_id: str, workflow_run_id: str, node_id: str, session_factory: sessionmaker) -> None:
    """Terminate a paused workflow run whose human input hit the global timeout.

    Inside a single transaction this:
      * marks the workflow run (when found) as FAILED, with an error message
        naming ``node_id`` and ``finished_at`` set to the current time;
      * looks up the run's pause record and, when found, deletes its stored
        state object and stamps ``resumed_at``.

    A failure to delete the state object is logged and does not abort the
    update. ``form_id`` is accepted but not used by this function.
    """
    timed_out_at = naive_utc_now()
    with session_factory() as session, session.begin():
        run = session.get(WorkflowRun, workflow_run_id)
        if run is not None:
            run.status = WorkflowExecutionStatus.FAILED
            run.error = f"Human input global timeout at node {node_id}"
            run.finished_at = timed_out_at
            session.add(run)

        pause_stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id)
        pause = session.scalar(pause_stmt)
        if pause is None:
            # Nothing to clean up; leaving the block commits the run update.
            return

        try:
            storage.delete(pause.state_object_key)
        except Exception:
            # Best effort: a leftover state object must not block the status update.
            logger.exception(
                "Failed to delete pause state object for workflow_run_id=%s, pause_id=%s",
                workflow_run_id,
                pause.id,
            )
        pause.resumed_at = timed_out_at
        session.add(pause)
@shared_task(name="human_input_form_timeout.check_and_resume")
def check_and_handle_human_input_timeouts(limit: int = 100) -> None:
    """Process human input forms whose expiration time has passed.

    Selects at most ``limit`` forms that are still WAITING and whose
    ``expiration_time`` is at or before the current time, then for each one:
      * marks the form as timed out;
      * on a global timeout, fails the workflow run and cleans up its pause;
      * otherwise enqueues a resume of the workflow run.

    Any exception raised while handling one form is logged and the remaining
    forms are still processed.

    Args:
        limit: Maximum number of expired forms fetched in one invocation.
    """
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    form_repo = HumanInputFormSubmissionRepository(session_factory)
    service = HumanInputService(session_factory, form_repository=form_repo)
    now = naive_utc_now()
    # A missing or falsy setting is treated as 0.
    configured_hours = getattr(dify_config, "HITL_GLOBAL_TIMEOUT_HOURS", 0)
    global_timeout_hours = int(configured_hours or 0)

    overdue_query = (
        select(HumanInputForm)
        .where(
            HumanInputForm.status == HumanInputFormStatus.WAITING,
            HumanInputForm.expiration_time <= now,
        )
        .limit(limit)
    )
    with session_factory() as session:
        overdue_forms = session.scalars(overdue_query).all()

    for form_model in overdue_forms:
        try:
            is_global = _is_global_timeout(form_model, global_timeout_hours)
            timeout_reason = "global_timeout" if is_global else "node_timeout"
            record = form_repo.mark_timeout(form_id=form_model.id, reason=timeout_reason)
            if not is_global:
                # Node-level timeout: resume the run after marking the form timed out.
                service._enqueue_resume(record.workflow_run_id)
                continue
            _handle_global_timeout(
                form_id=record.form_id,
                workflow_run_id=record.workflow_run_id,
                node_id=record.node_id,
                session_factory=session_factory,
            )
        except Exception:
            logger.exception(
                "Failed to handle timeout for form_id=%s workflow_run_id=%s",
                getattr(form_model, "id", None),
                getattr(form_model, "workflow_run_id", None),
            )

View File

@ -1,15 +1,13 @@
import logging
import time
from collections.abc import Mapping
from typing import Any
import click
from celery import shared_task
from configs import dify_config
from extensions.ext_mail import mail
from libs.email_template_renderer import render_email_template
from libs.email_i18n import get_email_i18n_service
from libs.email_template_renderer import render_email_template
logger = logging.getLogger(__name__)

View File

@ -1,5 +1,5 @@
from datetime import UTC, datetime
import uuid
from datetime import UTC, datetime
from unittest.mock import patch
import pytest

View File

@ -4,14 +4,17 @@ from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from core.workflow.nodes.human_input.entities import HumanInputFormStatus
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRecipientEntity,
HumanInputFormRepository,
)
from libs.datetime_utils import naive_utc_now
class _InMemoryFormRecipient(HumanInputFormRecipientEntity):
@ -38,6 +41,8 @@ class _InMemoryFormEntity(HumanInputFormEntity):
action_id: str | None = None
data: Mapping[str, Any] | None = None
is_submitted: bool = False
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
expiration: datetime = naive_utc_now()
@property
def id(self) -> str:
@ -67,6 +72,14 @@ class _InMemoryFormEntity(HumanInputFormEntity):
def submitted(self) -> bool:
return self.is_submitted
@property
def status(self) -> HumanInputFormStatus:
    """Expose the in-memory ``status_value`` field as the form status."""
    current = self.status_value
    return current
@property
def expiration_time(self) -> datetime:
    """Expose the in-memory ``expiration`` field as the form expiration time."""
    deadline = self.expiration
    return deadline
class InMemoryHumanInputFormRepository(HumanInputFormRepository):
"""Pure in-memory repository used by workflow graph engine tests."""
@ -100,6 +113,7 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
entity.action_id = action_id
entity.data = form_data or {}
entity.is_submitted = True
entity.status_value = HumanInputFormStatus.SUBMITTED
def clear_submission(self) -> None:
if not self.created_forms:
@ -108,3 +122,4 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
form.action_id = None
form.data = None
form.is_submitted = False
form.status_value = HumanInputFormStatus.WAITING

View File

@ -1,5 +1,3 @@
import pytest
from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients

View File

@ -1,5 +1,5 @@
import dataclasses
from datetime import datetime
from datetime import datetime, timedelta
from unittest.mock import MagicMock
import pytest
@ -8,7 +8,14 @@ from core.repositories.human_input_reposotiry import (
HumanInputFormRecord,
HumanInputFormSubmissionRepository,
)
from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, FormInputType, TimeoutUnit, UserAction
from core.workflow.nodes.human_input.entities import (
FormDefinition,
FormInput,
FormInputType,
HumanInputFormStatus,
TimeoutUnit,
UserAction,
)
from models.account import Account
from models.human_input import RecipientType
from services.human_input_service import FormSubmittedError, HumanInputService, InvalidFormDataError
@ -42,7 +49,8 @@ def sample_form_record():
timeout_unit=TimeoutUnit.HOUR,
),
rendered_content="<p>hello</p>",
expiration_time=datetime(2024, 1, 1),
expiration_time=datetime.utcnow() + timedelta(hours=1),
status=HumanInputFormStatus.WAITING,
selected_action_id=None,
submitted_data=None,
submitted_at=None,

View File

@ -1,4 +1,3 @@
import types
from collections.abc import Sequence
import pytest