mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
WIP: human input timeout
This commit is contained in:
parent
203a3a68af
commit
5d0dd329f2
@ -82,6 +82,12 @@ class AppExecutionConfig(BaseSettings):
|
||||
default=0,
|
||||
)
|
||||
|
||||
HITL_GLOBAL_TIMEOUT_HOURS: PositiveInt = Field(
|
||||
description="Maximum hours a workflow run can stay paused waiting for human input before global timeout.",
|
||||
default=24 * 7,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutionSandboxConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
@ -28,6 +28,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
@ -55,7 +56,6 @@ from core.workflow.graph_events import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
@ -63,9 +63,9 @@ from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from core.workflow.nodes.human_input.entities import (
|
||||
EmailRecipient,
|
||||
ExternalRecipient,
|
||||
FormDefinition,
|
||||
HumanInputFormStatus,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
WebAppDeliveryMethod,
|
||||
@ -106,6 +107,14 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
def submitted(self) -> bool:
|
||||
return self._form_model.submitted_at is not None
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self._form_model.status
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self._form_model.expiration_time
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class HumanInputFormRecord:
|
||||
@ -116,6 +125,7 @@ class HumanInputFormRecord:
|
||||
definition: FormDefinition
|
||||
rendered_content: str
|
||||
expiration_time: datetime
|
||||
status: HumanInputFormStatus
|
||||
selected_action_id: str | None
|
||||
submitted_data: Mapping[str, Any] | None
|
||||
submitted_at: datetime | None
|
||||
@ -142,6 +152,7 @@ class HumanInputFormRecord:
|
||||
definition=FormDefinition.model_validate_json(form_model.form_definition),
|
||||
rendered_content=form_model.rendered_content,
|
||||
expiration_time=form_model.expiration_time,
|
||||
status=form_model.status,
|
||||
selected_action_id=form_model.selected_action_id,
|
||||
submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None,
|
||||
submitted_at=form_model.submitted_at,
|
||||
@ -296,6 +307,8 @@ class HumanInputFormRepositoryImpl:
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
# Generate unique form ID
|
||||
form_id = str(uuidv7())
|
||||
start_time = naive_utc_now()
|
||||
node_expiration = form_config.expiration_time(start_time)
|
||||
form_definition = FormDefinition(
|
||||
form_content=form_config.form_content,
|
||||
inputs=form_config.inputs,
|
||||
@ -312,7 +325,8 @@ class HumanInputFormRepositoryImpl:
|
||||
node_id=params.node_id,
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content=params.rendered_content,
|
||||
expiration_time=form_config.expiration_time(naive_utc_now()),
|
||||
expiration_time=node_expiration,
|
||||
created_at=start_time,
|
||||
)
|
||||
session.add(form_model)
|
||||
recipient_models: list[HumanInputFormRecipient] = []
|
||||
@ -404,6 +418,7 @@ class HumanInputFormSubmissionRepository:
|
||||
form_model.selected_action_id = selected_action_id
|
||||
form_model.submitted_data = json.dumps(form_data)
|
||||
form_model.submitted_at = naive_utc_now()
|
||||
form_model.status = HumanInputFormStatus.SUBMITTED
|
||||
form_model.submission_user_id = submission_user_id
|
||||
form_model.submission_end_user_id = submission_end_user_id
|
||||
form_model.completed_by_recipient_id = recipient_id
|
||||
@ -415,3 +430,29 @@ class HumanInputFormSubmissionRepository:
|
||||
session.refresh(recipient_model)
|
||||
|
||||
return HumanInputFormRecord.from_models(form_model, recipient_model)
|
||||
|
||||
def mark_timeout(self, *, form_id: str, reason: str | None = None) -> HumanInputFormRecord:
    """Flag a human-input form as timed out and return its record.

    The call is a no-op (returning the current record) when the form is
    already in the TIMEOUT status.

    Args:
        form_id: Primary key of the form to mark.
        reason: Accepted for the caller's convenience but not stored on
            the form; it is not used inside this method.

    Returns:
        A ``HumanInputFormRecord`` built from the form row, with no recipient.

    Raises:
        FormNotFoundError: If no form has this id, or if the form has
            already been submitted.
    """
    # Submission-related columns that must be blank on a timed-out form.
    cleared_columns = (
        "selected_action_id",
        "submitted_data",
        "submission_user_id",
        "submission_end_user_id",
        "completed_by_recipient_id",
    )

    with self._session_factory(expire_on_commit=False) as session, session.begin():
        form = session.get(HumanInputForm, form_id)
        if form is None:
            raise FormNotFoundError(f"form not found, id={form_id}")

        # Timeout was already recorded: nothing left to do.
        if form.status == HumanInputFormStatus.TIMEOUT:
            return HumanInputFormRecord.from_models(form, None)

        # A submitted form must never be overwritten by a timeout.
        was_submitted = form.status == HumanInputFormStatus.SUBMITTED or form.submitted_at is not None
        if was_submitted:
            raise FormNotFoundError(f"form already submitted, id={form_id}")

        form.status = HumanInputFormStatus.TIMEOUT
        for column in cleared_columns:
            setattr(form, column, None)

        session.add(form)
        session.flush()
        session.refresh(form)

    return HumanInputFormRecord.from_models(form, None)
|
||||
|
||||
@ -15,8 +15,9 @@ from core.workflow.repositories.human_input_form_repository import (
|
||||
)
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import HumanInputNodeData, PlaceholderType
|
||||
from .entities import HumanInputFormStatus, HumanInputNodeData, PlaceholderType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
@ -221,6 +222,14 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
edge_source_handle=selected_action_id,
|
||||
)
|
||||
|
||||
if form.status == HumanInputFormStatus.TIMEOUT or form.expiration_time <= naive_utc_now():
|
||||
outputs: dict[str, Any] = {"__rendered_content": form.rendered_content}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
edge_source_handle="__timeout",
|
||||
)
|
||||
|
||||
return self._pause_with_form(form)
|
||||
|
||||
def _pause_with_form(self, form_entity: HumanInputFormEntity) -> Generator[NodeEventBase, None, None]:
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import abc
|
||||
import dataclasses
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputFormStatus, HumanInputNodeData
|
||||
|
||||
|
||||
class HumanInputError(Exception):
|
||||
@ -82,6 +83,18 @@ class HumanInputFormEntity(abc.ABC):
|
||||
"""Whether the form has been submitted."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
"""Current status of the form."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def expiration_time(self) -> datetime:
|
||||
"""When the form expires."""
|
||||
...
|
||||
|
||||
|
||||
class HumanInputFormRecipientEntity(abc.ABC):
|
||||
@property
|
||||
|
||||
@ -9,7 +9,8 @@ from core.repositories.human_input_reposotiry import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormSubmissionRepository,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition, HumanInputFormStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.exception import BaseHTTPException
|
||||
from models.account import Account
|
||||
from models.human_input import RecipientType
|
||||
@ -48,6 +49,14 @@ class Form:
|
||||
def recipient_type(self) -> RecipientType | None:
|
||||
return self._record.recipient_type
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self._record.status
|
||||
|
||||
@property
|
||||
def expiration_time(self):
|
||||
return self._record.expiration_time
|
||||
|
||||
|
||||
class HumanInputError(Exception):
|
||||
pass
|
||||
@ -80,6 +89,14 @@ class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
|
||||
pass
|
||||
|
||||
|
||||
class FormExpiredError(HumanInputError, BaseHTTPException):
    """Raised when an action targets a human-input form that has expired.

    ``code`` is 412; presumably the HTTP status sent to the client
    (confirm against ``BaseHTTPException``).
    """

    error_code = "human_input_form_expired"
    code = 412

    def __init__(self, form_id: str):
        """Build the error with a description that names the expired form."""
        message = f"This form has expired, form_id={form_id}"
        super().__init__(description=message)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -134,7 +151,7 @@ class HumanInputService:
|
||||
if form is None:
|
||||
raise WebAppDeliveryNotEnabledError()
|
||||
|
||||
self._ensure_not_submitted(form)
|
||||
self._ensure_form_active(form)
|
||||
self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data)
|
||||
|
||||
result = self._form_repository.mark_submitted(
|
||||
@ -160,7 +177,7 @@ class HumanInputService:
|
||||
if form is None or form.recipient_type != recipient_type:
|
||||
raise WebAppDeliveryNotEnabledError()
|
||||
|
||||
self._ensure_not_submitted(form)
|
||||
self._ensure_form_active(form)
|
||||
self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data)
|
||||
|
||||
result = self._form_repository.mark_submitted(
|
||||
@ -174,6 +191,15 @@ class HumanInputService:
|
||||
|
||||
self._enqueue_resume(result.workflow_run_id)
|
||||
|
||||
def _ensure_form_active(self, form: Form) -> None:
    """Reject a form that can no longer accept a submission.

    Raises:
        FormSubmittedError: If the form was already submitted.
        FormExpiredError: If the form is in the TIMEOUT status or its
            expiration time is not after the current naive-UTC time.
    """
    if form.submitted:
        raise FormSubmittedError(form.id)

    # The clock is only consulted when the status alone does not settle it.
    expired = form.status == HumanInputFormStatus.TIMEOUT or form.expiration_time <= naive_utc_now()
    if expired:
        raise FormExpiredError(form.id)
|
||||
|
||||
def _ensure_not_submitted(self, form: Form) -> None:
    """Raise ``FormSubmittedError`` if *form* has already been submitted."""
    already_submitted = form.submitted
    if already_submitted:
        raise FormSubmittedError(form.id)
|
||||
|
||||
110
api/tasks/human_input_timeout_tasks.py
Normal file
110
api/tasks/human_input_timeout_tasks.py
Normal file
@ -0,0 +1,110 @@
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repositories.human_input_reposotiry import HumanInputFormSubmissionRepository
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition, HumanInputFormStatus, TimeoutUnit
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
|
||||
from models.human_input import HumanInputForm
|
||||
from models.workflow import WorkflowPause, WorkflowRun
|
||||
from services.human_input_service import HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _calculate_node_deadline(definition: FormDefinition, created_at, *, start_time=None):
    """Return the moment at which the node-level timeout of a form elapses.

    Args:
        definition: Form definition supplying ``timeout`` and ``timeout_unit``.
        created_at: Creation time of the form; used as the base time when
            ``start_time`` is falsy.
        start_time: Optional base time that overrides ``created_at``.

    Returns:
        The base time plus the configured timeout.

    Raises:
        AssertionError: If the timeout unit is neither HOUR nor DAY.
    """
    base = start_time or created_at
    unit = definition.timeout_unit

    if unit == TimeoutUnit.HOUR:
        offset = timedelta(hours=definition.timeout)
    elif unit == TimeoutUnit.DAY:
        offset = timedelta(days=definition.timeout)
    else:
        raise AssertionError("unknown timeout unit.")

    return base + offset
|
||||
|
||||
|
||||
def _is_global_timeout(form_model: HumanInputForm, global_timeout_hours: int) -> bool:
|
||||
if global_timeout_hours <= 0:
|
||||
return False
|
||||
|
||||
form_definition = FormDefinition.model_validate_json(form_model.form_definition)
|
||||
|
||||
created_at = ensure_naive_utc(form_model.created_at)
|
||||
expiration_time = ensure_naive_utc(form_model.expiration_time)
|
||||
node_deadline = _calculate_node_deadline(form_definition, created_at)
|
||||
global_deadline = created_at + timedelta(hours=global_timeout_hours)
|
||||
return global_deadline <= node_deadline and expiration_time <= global_deadline
|
||||
|
||||
|
||||
def _handle_global_timeout(
    *,
    form_id: str,
    workflow_run_id: str,
    node_id: str,
    session_factory: sessionmaker,
) -> None:
    """End a paused workflow run whose human-input wait hit the global timeout.

    Within one transaction the workflow run (if found) is marked FAILED
    with an error naming the node, and the pause record (if found) has its
    stored state object deleted and ``resumed_at`` set. A failure to delete
    the state object is logged and does not abort the update.

    Args:
        form_id: Id of the timed-out form; not used inside this function.
        workflow_run_id: Id of the workflow run to fail.
        node_id: Id of the human-input node, included in the error text.
        session_factory: Factory producing the database session to use.
    """
    timestamp = naive_utc_now()

    with session_factory() as session:
        with session.begin():
            run = session.get(WorkflowRun, workflow_run_id)
            if run is not None:
                run.status = WorkflowExecutionStatus.FAILED
                run.error = f"Human input global timeout at node {node_id}"
                run.finished_at = timestamp
                session.add(run)

            pause_query = select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id)
            pause = session.scalar(pause_query)
            if pause is not None:
                # Best effort: a storage failure must not block the DB update.
                try:
                    storage.delete(pause.state_object_key)
                except Exception:
                    logger.exception(
                        "Failed to delete pause state object for workflow_run_id=%s, pause_id=%s",
                        workflow_run_id,
                        pause.id,
                    )
                pause.resumed_at = timestamp
                session.add(pause)
|
||||
|
||||
|
||||
@shared_task(name="human_input_form_timeout.check_and_resume")
def check_and_handle_human_input_timeouts(limit: int = 100) -> None:
    """Scan for expired human input forms and resume or end workflows.

    Up to ``limit`` forms that are still WAITING and whose expiration time
    has passed are marked as timed out. For a global timeout the workflow
    run is ended via ``_handle_global_timeout``; otherwise a resume of the
    workflow run is enqueued. A failure on one form is logged and does not
    stop processing of the remaining forms.

    Args:
        limit: Maximum number of expired forms handled in one invocation.
    """
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    form_repo = HumanInputFormSubmissionRepository(session_factory)
    service = HumanInputService(session_factory, form_repository=form_repo)
    now = naive_utc_now()
    # A missing or falsy setting is treated as 0, i.e. global timeout disabled.
    global_timeout_hours = int(getattr(dify_config, "HITL_GLOBAL_TIMEOUT_HOURS", 0) or 0)

    still_waiting = HumanInputForm.status == HumanInputFormStatus.WAITING
    past_deadline = HumanInputForm.expiration_time <= now
    query = select(HumanInputForm).where(still_waiting, past_deadline).limit(limit)

    with session_factory() as session:
        expired_forms = session.scalars(query).all()

    for form_model in expired_forms:
        try:
            is_global = _is_global_timeout(form_model, global_timeout_hours)
            reason = "global_timeout" if is_global else "node_timeout"
            record = form_repo.mark_timeout(form_id=form_model.id, reason=reason)

            if not is_global:
                service._enqueue_resume(record.workflow_run_id)
                continue

            _handle_global_timeout(
                form_id=record.form_id,
                workflow_run_id=record.workflow_run_id,
                node_id=record.node_id,
                session_factory=session_factory,
            )
        except Exception:
            logger.exception(
                "Failed to handle timeout for form_id=%s workflow_run_id=%s",
                getattr(form_model, "id", None),
                getattr(form_model, "workflow_run_id", None),
            )
|
||||
@ -1,15 +1,13 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_mail import mail
|
||||
from libs.email_template_renderer import render_email_template
|
||||
from libs.email_i18n import get_email_i18n_service
|
||||
from libs.email_template_renderer import render_email_template
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -4,14 +4,17 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.nodes.human_input.entities import HumanInputFormStatus
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRecipientEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class _InMemoryFormRecipient(HumanInputFormRecipientEntity):
|
||||
@ -38,6 +41,8 @@ class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
is_submitted: bool = False
|
||||
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
expiration: datetime = naive_utc_now()
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
@ -67,6 +72,14 @@ class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
def submitted(self) -> bool:
|
||||
return self.is_submitted
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self.status_value
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self.expiration
|
||||
|
||||
|
||||
class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
"""Pure in-memory repository used by workflow graph engine tests."""
|
||||
@ -100,6 +113,7 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
entity.action_id = action_id
|
||||
entity.data = form_data or {}
|
||||
entity.is_submitted = True
|
||||
entity.status_value = HumanInputFormStatus.SUBMITTED
|
||||
|
||||
def clear_submission(self) -> None:
|
||||
if not self.created_forms:
|
||||
@ -108,3 +122,4 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
form.action_id = None
|
||||
form.data = None
|
||||
form.is_submitted = False
|
||||
form.status_value = HumanInputFormStatus.WAITING
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import dataclasses
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -8,7 +8,14 @@ from core.repositories.human_input_reposotiry import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormSubmissionRepository,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, FormInputType, TimeoutUnit, UserAction
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormDefinition,
|
||||
FormInput,
|
||||
FormInputType,
|
||||
HumanInputFormStatus,
|
||||
TimeoutUnit,
|
||||
UserAction,
|
||||
)
|
||||
from models.account import Account
|
||||
from models.human_input import RecipientType
|
||||
from services.human_input_service import FormSubmittedError, HumanInputService, InvalidFormDataError
|
||||
@ -42,7 +49,8 @@ def sample_form_record():
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=datetime(2024, 1, 1),
|
||||
expiration_time=datetime.utcnow() + timedelta(hours=1),
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
selected_action_id=None,
|
||||
submitted_data=None,
|
||||
submitted_at=None,
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import types
|
||||
from collections.abc import Sequence
|
||||
|
||||
import pytest
|
||||
|
||||
Loading…
Reference in New Issue
Block a user