diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml
index ac7f3a6b48..704d896192 100644
--- a/.github/workflows/build-push.yml
+++ b/.github/workflows/build-push.yml
@@ -8,7 +8,6 @@ on:
       - "build/**"
       - "release/e-*"
       - "hotfix/**"
-      - "feat/hitl-backend"
     tags:
       - "*"

diff --git a/api/.env.example b/api/.env.example
index fcadfa1c3b..8bd2c706c1 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -717,28 +717,3 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
 SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
-
-
-# Redis URL used for PubSub between API and
-# celery worker
-# defaults to url constructed from `REDIS_*`
-# configurations
-PUBSUB_REDIS_URL=
-# Pub/sub channel type for streaming events.
-# valid options are:
-#
-# - pubsub: for normal Pub/Sub
-# - sharded: for sharded Pub/Sub
-#
-# It's highly recommended to use sharded Pub/Sub AND redis cluster
-# for large deployments.
-PUBSUB_REDIS_CHANNEL_TYPE=pubsub
-# Whether to use Redis cluster mode while running
-# PubSub.
-# It's highly recommended to enable this for large deployments.
-PUBSUB_REDIS_USE_CLUSTERS=false
-
-# Whether to Enable human input timeout check task
-ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
-# Human input timeout check interval in minutes
-HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1

diff --git a/api/.importlinter b/api/.importlinter
index 98f87710ed..9dad254560 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -36,8 +36,6 @@ ignore_imports =
     core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
     core.workflow.nodes.loop.loop_node -> core.workflow.graph
     core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
-    # TODO(QuantumGhost): fix the import violation later
-    core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities
 
 [importlinter:contract:workflow-infrastructure-dependencies]
 name = Workflow Infrastructure Dependencies
@@ -60,8 +58,6 @@ ignore_imports =
     core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
     core.workflow.graph_engine.manager -> extensions.ext_redis
     core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
-    # TODO(QuantumGhost): use DI to avoid depending on global DB.
-    core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
 
 [importlinter:contract:workflow-external-imports]
 name = Workflow External Imports
@@ -149,7 +145,6 @@ ignore_imports =
     core.workflow.nodes.agent.agent_node -> core.agent.entities
     core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
     core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
-    core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
     core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
     core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
     core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
@@ -253,7 +248,6 @@ ignore_imports =
     core.workflow.nodes.document_extractor.node -> core.variables.segments
     core.workflow.nodes.http_request.executor -> core.variables.segments
     core.workflow.nodes.http_request.node -> core.variables.segments
-    core.workflow.nodes.human_input.entities -> core.variables.consts
     core.workflow.nodes.iteration.iteration_node -> core.variables
     core.workflow.nodes.iteration.iteration_node -> core.variables.segments
     core.workflow.nodes.iteration.iteration_node -> core.variables.variables
@@ -300,8 +294,6 @@ ignore_imports =
     core.workflow.nodes.llm.llm_utils -> extensions.ext_database
     core.workflow.nodes.llm.node -> extensions.ext_database
     core.workflow.nodes.tool.tool_node -> extensions.ext_database
-    core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
-    core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository
     core.workflow.workflow_entry -> extensions.otel.runtime
     core.workflow.nodes.agent.agent_node -> models
     core.workflow.nodes.base.node -> models.enums

diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index c405d5d44c..d97e9a0440 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -1,4 +1,3 @@
-from datetime import timedelta
 from enum import StrEnum
 from typing import Literal
 
@@ -49,16 +48,6 @@ class SecurityConfig(BaseSettings):
         default=5,
     )
 
-    WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS: PositiveInt = Field(
-        description="Maximum number of web form submissions allowed per IP within the rate limit window",
-        default=30,
-    )
-
-    WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS: PositiveInt = Field(
-        description="Time window in seconds for web form submission rate limiting",
-        default=60,
-    )
-
     LOGIN_DISABLED: bool = Field(
         description="Whether to disable login checks",
         default=False,
@@ -93,12 +82,6 @@ class AppExecutionConfig(BaseSettings):
         default=0,
     )
 
-    HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS: PositiveInt = Field(
-        description="Maximum seconds a workflow run can stay paused waiting for human input before global timeout.",
-        default=int(timedelta(days=7).total_seconds()),
-        ge=1,
-    )
-
 
 class CodeExecutionSandboxConfig(BaseSettings):
     """
@@ -1151,14 +1134,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
         description="Enable queue monitor task",
         default=False,
     )
-    ENABLE_HUMAN_INPUT_TIMEOUT_TASK: bool = Field(
-        description="Enable human input timeout check task",
-        default=True,
-    )
-    HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: PositiveInt = Field(
-        description="Human input timeout check interval in minutes",
-        default=1,
-    )
     ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
         description="Enable check upgradable plugin task",
         default=True,

diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index a15e42babf..63f75924bf 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -6,7 +6,6 @@ from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, Pos
 from pydantic_settings import BaseSettings
 
 from .cache.redis_config import RedisConfig
-from .cache.redis_pubsub_config import RedisPubSubConfig
 from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
 from .storage.amazon_s3_storage_config import S3StorageConfig
 from .storage.azure_blob_storage_config import AzureBlobStorageConfig
@@ -318,7 +317,6 @@ class MiddlewareConfig(
     CeleryConfig,  # Note: CeleryConfig already inherits from DatabaseConfig
     KeywordStoreConfig,
     RedisConfig,
-    RedisPubSubConfig,
     # configs of storage and storage providers
     StorageConfig,
     AliyunOSSStorageConfig,

diff --git a/api/configs/middleware/cache/redis_pubsub_config.py b/api/configs/middleware/cache/redis_pubsub_config.py
deleted file mode 100644
index a72e1dd28f..0000000000
--- a/api/configs/middleware/cache/redis_pubsub_config.py
+++ /dev/null
@@ -1,96 +0,0 @@
-from typing import Literal, Protocol
-from urllib.parse import quote_plus, urlunparse
-
-from pydantic import Field
-from pydantic_settings import BaseSettings
-
-
-class RedisConfigDefaults(Protocol):
-    REDIS_HOST: str
-    REDIS_PORT: int
-    REDIS_USERNAME: str | None
-    REDIS_PASSWORD: str | None
-    REDIS_DB: int
-    REDIS_USE_SSL: bool
-    REDIS_USE_SENTINEL: bool | None
-    REDIS_USE_CLUSTERS: bool
-
-
-class RedisConfigDefaultsMixin:
-    def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
-        return self
-
-
-class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
-    """
-    Configuration settings for Redis pub/sub streaming.
-    """
-
-    PUBSUB_REDIS_URL: str | None = Field(
-        alias="PUBSUB_REDIS_URL",
-        description=(
-            "Redis connection URL for pub/sub streaming events between API "
-            "and celery worker, defaults to url constructed from "
-            "`REDIS_*` configurations"
-        ),
-        default=None,
-    )
-
-    PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
-        description=(
-            "Enable Redis Cluster mode for pub/sub streaming. It's highly "
-            "recommended to enable this for large deployments."
-        ),
-        default=False,
-    )
-
-    PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
-        description=(
-            "Pub/sub channel type for streaming events. "
-            "Valid options are:\n"
-            "\n"
-            " - pubsub: for normal Pub/Sub\n"
-            " - sharded: for sharded Pub/Sub\n"
-            "\n"
-            "It's highly recommended to use sharded Pub/Sub AND redis cluster "
-            "for large deployments."
-        ),
-        default="pubsub",
-    )
-
-    def _build_default_pubsub_url(self) -> str:
-        defaults = self._redis_defaults()
-        if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
-            raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
-
-        scheme = "rediss" if defaults.REDIS_USE_SSL else "redis"
-        username = defaults.REDIS_USERNAME or None
-        password = defaults.REDIS_PASSWORD or None
-
-        userinfo = ""
-        if username:
-            userinfo = quote_plus(username)
-        if password:
-            password_part = quote_plus(password)
-            userinfo = f"{userinfo}:{password_part}" if userinfo else f":{password_part}"
-        if userinfo:
-            userinfo = f"{userinfo}@"
-
-        host = defaults.REDIS_HOST
-        port = defaults.REDIS_PORT
-        db = defaults.REDIS_DB
-
-        netloc = f"{userinfo}{host}:{port}"
-        return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
-
-    @property
-    def normalized_pubsub_redis_url(self) -> str:
-        pubsub_redis_url = self.PUBSUB_REDIS_URL
-        if pubsub_redis_url:
-            cleaned = pubsub_redis_url.strip()
-            pubsub_redis_url = cleaned or None
-
-        if pubsub_redis_url:
-            return pubsub_redis_url
-
-        return self._build_default_pubsub_url()

diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 902d67174b..fdc9aabc83 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -37,7 +37,6 @@ from . import (
     apikey,
     extension,
     feature,
-    human_input_form,
     init_validate,
     ping,
     setup,
@@ -172,7 +171,6 @@ __all__ = [
     "forgot_password",
     "generator",
     "hit_testing",
-    "human_input_form",
     "init_validate",
     "installed_app",
     "load_balancing_config",

diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 14910c5895..55fdcb51e4 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -89,7 +89,6 @@ status_count_model = console_ns.model(
         "success": fields.Integer,
         "failed": fields.Integer,
         "partial_success": fields.Integer,
-        "paused": fields.Integer,
     },
 )

diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index ab1628d5d4..12ada8b798 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -32,7 +32,7 @@ from libs.login import current_account_with_tenant, login_required
 from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
-from services.message_service import MessageService, attach_message_extra_contents
+from services.message_service import MessageService
 
 logger = logging.getLogger(__name__)
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -198,7 +198,6 @@ message_detail_model = console_ns.model(
         "created_at": TimestampField,
         "agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
         "message_files": fields.List(fields.Nested(message_file_model)),
-        "extra_contents": fields.List(fields.Raw),
         "metadata": fields.Raw(attribute="message_metadata_dict"),
         "status": fields.String,
         "error": fields.String,
@@ -291,7 +290,6 @@ class ChatMessageListApi(Resource):
             has_more = False
 
         history_messages = list(reversed(history_messages))
-        attach_message_extra_contents(history_messages)
 
         return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
 
@@ -476,5 +474,4 @@ class MessageApi(Resource):
         if not message:
             raise NotFound("Message Not Exists.")
 
-        attach_message_extra_contents([message])
         return message

diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 27e1d01af6..755463cb70 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -507,179 +507,6 @@ class WorkflowDraftRunLoopNodeApi(Resource):
             raise InternalServerError()
 
 
-class HumanInputFormPreviewPayload(BaseModel):
-    inputs: dict[str, Any] = Field(
-        default_factory=dict,
-        description="Values used to fill missing upstream variables referenced in form_content",
-    )
-
-
-class HumanInputFormSubmitPayload(BaseModel):
-    form_inputs: dict[str, Any] = Field(..., description="Values the user provides for the form's own fields")
-    inputs: dict[str, Any] = Field(
-        ...,
-        description="Values used to fill missing upstream variables referenced in form_content",
-    )
-    action: str = Field(..., description="Selected action ID")
-
-
-class HumanInputDeliveryTestPayload(BaseModel):
-    delivery_method_id: str = Field(..., description="Delivery method ID")
-    inputs: dict[str, Any] = Field(
-        default_factory=dict,
-        description="Values used to fill missing upstream variables referenced in form_content",
-    )
-
-
-reg(HumanInputFormPreviewPayload)
-reg(HumanInputFormSubmitPayload)
-reg(HumanInputDeliveryTestPayload)
-
-
-@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
-class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
-    @console_ns.doc("get_advanced_chat_draft_human_input_form")
-    @console_ns.doc(description="Get human input form preview for advanced chat workflow")
-    @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
-    @console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.ADVANCED_CHAT])
-    @edit_permission_required
-    def post(self, app_model: App, node_id: str):
-        """
-        Preview human input form content and placeholders
-        """
-        current_user, _ = current_account_with_tenant()
-        args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
-        inputs = args.inputs
-
-        workflow_service = WorkflowService()
-        preview = workflow_service.get_human_input_form_preview(
-            app_model=app_model,
-            account=current_user,
-            node_id=node_id,
-            inputs=inputs,
-        )
-        return jsonable_encoder(preview)
-
-
-@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/run")
-class AdvancedChatDraftHumanInputFormRunApi(Resource):
-    @console_ns.doc("submit_advanced_chat_draft_human_input_form")
-    @console_ns.doc(description="Submit human input form preview for advanced chat workflow")
-    @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
-    @console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.ADVANCED_CHAT])
-    @edit_permission_required
-    def post(self, app_model: App, node_id: str):
-        """
-        Submit human input form preview
-        """
-        current_user, _ = current_account_with_tenant()
-        args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
-        workflow_service = WorkflowService()
-        result = workflow_service.submit_human_input_form_preview(
-            app_model=app_model,
-            account=current_user,
-            node_id=node_id,
-            form_inputs=args.form_inputs,
-            inputs=args.inputs,
-            action=args.action,
-        )
-        return jsonable_encoder(result)
-
-
-@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
-class WorkflowDraftHumanInputFormPreviewApi(Resource):
-    @console_ns.doc("get_workflow_draft_human_input_form")
-    @console_ns.doc(description="Get human input form preview for workflow")
-    @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
-    @console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.WORKFLOW])
-    @edit_permission_required
-    def post(self, app_model: App, node_id: str):
-        """
-        Preview human input form content and placeholders
-        """
-        current_user, _ = current_account_with_tenant()
-        args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
-        inputs = args.inputs
-
-        workflow_service = WorkflowService()
-        preview = workflow_service.get_human_input_form_preview(
-            app_model=app_model,
-            account=current_user,
-            node_id=node_id,
-            inputs=inputs,
-        )
-        return jsonable_encoder(preview)
-
-
-@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/run")
-class WorkflowDraftHumanInputFormRunApi(Resource):
-    @console_ns.doc("submit_workflow_draft_human_input_form")
-    @console_ns.doc(description="Submit human input form preview for workflow")
-    @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
-    @console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.WORKFLOW])
-    @edit_permission_required
-    def post(self, app_model: App, node_id: str):
-        """
-        Submit human input form preview
-        """
-        current_user, _ = current_account_with_tenant()
-        workflow_service = WorkflowService()
-        args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
-        result = workflow_service.submit_human_input_form_preview(
-            app_model=app_model,
-            account=current_user,
-            node_id=node_id,
-            form_inputs=args.form_inputs,
-            inputs=args.inputs,
-            action=args.action,
-        )
-        return jsonable_encoder(result)
-
-
-@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/delivery-test")
-class WorkflowDraftHumanInputDeliveryTestApi(Resource):
-    @console_ns.doc("test_workflow_draft_human_input_delivery")
-    @console_ns.doc(description="Test human input delivery for workflow")
-    @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
-    @console_ns.expect(console_ns.models[HumanInputDeliveryTestPayload.__name__])
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
-    @edit_permission_required
-    def post(self, app_model: App, node_id: str):
-        """
-        Test human input delivery
-        """
-        current_user, _ = current_account_with_tenant()
-        workflow_service = WorkflowService()
-        args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {})
-        workflow_service.test_human_input_delivery(
-            app_model=app_model,
-            account=current_user,
-            node_id=node_id,
-            delivery_method_id=args.delivery_method_id,
-            inputs=args.inputs,
-        )
-        return jsonable_encoder({})
-
-
 @console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
 class DraftWorkflowRunApi(Resource):
     @console_ns.doc("run_draft_workflow")

diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index d9a5dde55a..fa74f8aea1 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -5,15 +5,10 @@ from flask import request
 from flask_restx import Resource, fields, marshal_with
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker
 
-from configs import dify_config
 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
-from controllers.web.error import NotFoundError
-from core.workflow.entities.pause_reason import HumanInputRequired
-from core.workflow.enums import WorkflowExecutionStatus
 from extensions.ext_database import db
 from fields.end_user_fields import simple_end_user_fields
 from fields.member_fields import simple_account_fields
@@ -32,21 +27,9 @@ from libs.custom_inputs import time_duration
 from libs.helper import uuid_value
 from libs.login import current_user, login_required
 from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
-from models.workflow import WorkflowRun
-from repositories.factory import DifyAPIRepositoryFactory
 from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
 from services.workflow_run_service import WorkflowRunService
 
-
-def _build_backstage_input_url(form_token: str | None) -> str | None:
-    if not form_token:
-        return None
-    base_url = dify_config.APP_WEB_URL
-    if not base_url:
-        return None
-    return f"{base_url.rstrip('/')}/form/{form_token}"
-
-
 # Workflow run status choices for filtering
 WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
 EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
@@ -457,63 +440,3 @@ class WorkflowRunNodeExecutionListApi(Resource):
         )
 
         return {"data": node_executions}
-
-
-@console_ns.route("/workflow/<string:workflow_run_id>/pause-details")
-class ConsoleWorkflowPauseDetailsApi(Resource):
-    """Console API for getting workflow pause details."""
-
-    @account_initialization_required
-    @login_required
-    def get(self, workflow_run_id: str):
-        """
-        Get workflow pause details.
-
-        GET /console/api/workflow/<workflow_run_id>/pause-details
-
-        Returns information about why and where the workflow is paused.
- """ - - # Query WorkflowRun to determine if workflow is suspended - session_maker = sessionmaker(bind=db.engine) - workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker) - workflow_run = db.session.get(WorkflowRun, workflow_run_id) - if not workflow_run: - raise NotFoundError("Workflow run not found") - - # Check if workflow is suspended - is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED - if not is_paused: - return { - "paused_at": None, - "paused_nodes": [], - }, 200 - - pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) - pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] - - # Build response - paused_at = pause_entity.paused_at if pause_entity else None - paused_nodes = [] - response = { - "paused_at": paused_at.isoformat() + "Z" if paused_at else None, - "paused_nodes": paused_nodes, - } - - for reason in pause_reasons: - if isinstance(reason, HumanInputRequired): - paused_nodes.append( - { - "node_id": reason.node_id, - "node_title": reason.node_title, - "pause_type": { - "type": "human_input", - "form_id": reason.form_id, - "backstage_input_url": _build_backstage_input_url(reason.form_token), - }, - } - ) - else: - raise AssertionError("unimplemented.") - - return response, 200 diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py deleted file mode 100644 index 7207f7fd1d..0000000000 --- a/api/controllers/console/human_input_form.py +++ /dev/null @@ -1,217 +0,0 @@ -""" -Console/Studio Human Input Form APIs. -""" - -import json -import logging -from collections.abc import Generator - -from flask import Response, jsonify, request -from flask_restx import Resource, reqparse -from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker - -from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, setup_required -from controllers.web.error import InvalidArgumentError, NotFoundError -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.app.apps.message_generator import MessageGenerator -from core.app.apps.workflow.app_generator import WorkflowAppGenerator -from extensions.ext_database import db -from libs.login import current_account_with_tenant, login_required -from models import App -from models.enums import CreatorUserRole -from models.human_input import RecipientType -from models.model import AppMode -from models.workflow import WorkflowRun -from repositories.factory import DifyAPIRepositoryFactory -from services.human_input_service import Form, HumanInputService -from services.workflow_event_snapshot_service import build_workflow_event_stream - -logger = logging.getLogger(__name__) - - -def _jsonify_form_definition(form: Form) -> Response: - payload = form.get_definition().model_dump() - payload["expiration_time"] = int(form.expiration_time.timestamp()) - return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json") - - -@console_ns.route("/form/human_input/") -class ConsoleHumanInputFormApi(Resource): - """Console API for getting human input form definition.""" - - @staticmethod - def _ensure_console_access(form: Form): - _, current_tenant_id = current_account_with_tenant() - - if form.tenant_id != current_tenant_id: - raise NotFoundError("App not found") - - @setup_required - @login_required - 
@account_initialization_required - def get(self, form_token: str): - """ - Get human input form definition by form token. - - GET /console/api/form/human_input/ - """ - service = HumanInputService(db.engine) - form = service.get_form_definition_by_token_for_console(form_token) - if form is None: - raise NotFoundError(f"form not found, token={form_token}") - - self._ensure_console_access(form) - - return _jsonify_form_definition(form) - - @account_initialization_required - @login_required - def post(self, form_token: str): - """ - Submit human input form by form token. - - POST /console/api/form/human_input/ - - Request body: - { - "inputs": { - "content": "User input content" - }, - "action": "Approve" - } - """ - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("action", type=str, required=True, location="json") - args = parser.parse_args() - current_user, _ = current_account_with_tenant() - - service = HumanInputService(db.engine) - form = service.get_form_by_token(form_token) - if form is None: - raise NotFoundError(f"form not found, token={form_token}") - - self._ensure_console_access(form) - - recipient_type = form.recipient_type - if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}: - raise NotFoundError(f"form not found, token={form_token}") - # The type checker is not smart enought to validate the following invariant. - # So we need to assert it manually. - assert recipient_type is not None, "recipient_type cannot be None here." - - service.submit_form_by_token( - recipient_type=recipient_type, - form_token=form_token, - selected_action_id=args["action"], - form_data=args["inputs"], - submission_user_id=current_user.id, - ) - - return jsonify({}) - - -@console_ns.route("/workflow//events") -class ConsoleWorkflowEventsApi(Resource): - """Console API for getting workflow execution events after resume.""" - - @account_initialization_required - @login_required - def get(self, workflow_run_id: str): - """ - Get workflow execution events stream after resume. - - GET /console/api/workflow//events - - Returns Server-Sent Events stream. - """ - - user, tenant_id = current_account_with_tenant() - session_maker = sessionmaker(db.engine) - repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) - workflow_run = repo.get_workflow_run_by_id_and_tenant_id( - tenant_id=tenant_id, - run_id=workflow_run_id, - ) - if workflow_run is None: - raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}") - - if workflow_run.created_by_role != CreatorUserRole.ACCOUNT: - raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}") - - if workflow_run.created_by != user.id: - raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}") - - with Session(expire_on_commit=False, bind=db.engine) as session: - app = _retrieve_app_for_workflow_run(session, workflow_run) - - if workflow_run.finished_at is not None: - # TODO(QuantumGhost): should we modify the handling for finished workflow run here? 
-            response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
-                task_id=workflow_run.id,
-                workflow_run=workflow_run,
-                creator_user=user,
-            )
-
-            payload = response.model_dump(mode="json")
-            payload["event"] = response.event.value
-
-            def _generate_finished_events() -> Generator[str, None, None]:
-                yield f"data: {json.dumps(payload)}\n\n"
-
-            event_generator = _generate_finished_events
-
-        else:
-            msg_generator = MessageGenerator()
-            if app.mode == AppMode.ADVANCED_CHAT:
-                generator = AdvancedChatAppGenerator()
-            elif app.mode == AppMode.WORKFLOW:
-                generator = WorkflowAppGenerator()
-            else:
-                raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
-
-            include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
-
-            def _generate_stream_events():
-                if include_state_snapshot:
-                    return generator.convert_to_event_stream(
-                        build_workflow_event_stream(
-                            app_mode=AppMode(app.mode),
-                            workflow_run=workflow_run,
-                            tenant_id=workflow_run.tenant_id,
-                            app_id=workflow_run.app_id,
-                            session_maker=session_maker,
-                        )
-                    )
-                return generator.convert_to_event_stream(
-                    msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id),
-                )
-
-            event_generator = _generate_stream_events
-
-        return Response(
-            event_generator(),
-            mimetype="text/event-stream",
-            headers={
-                "Cache-Control": "no-cache",
-                "Connection": "keep-alive",
-            },
-        )
-
-
-def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
-    query = select(App).where(
-        App.id == workflow_run.app_id,
-        App.tenant_id == workflow_run.tenant_id,
-    )
-    app = session.scalars(query).first()
-    if app is None:
-        raise AssertionError(
-            f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, "
-            f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}"
-        )
-
-    return app

diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py
index 6088b142c2..6a549fc926 100644
--- a/api/controllers/service_api/app/workflow.py
+++ b/api/controllers/service_api/app/workflow.py
@@ -33,9 +33,8 @@ from core.workflow.graph_engine.manager import GraphEngineManager
 from extensions.ext_database import db
 from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
 from libs import helper
-from libs.helper import OptionalTimestampField, TimestampField
+from libs.helper import TimestampField
 from models.model import App, AppMode, EndUser
-from models.workflow import WorkflowRun
 from repositories.factory import DifyAPIRepositoryFactory
 from services.app_generate_service import AppGenerateService
 from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
@@ -64,32 +63,17 @@ class WorkflowLogQuery(BaseModel):
 
 register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
 
-
-class WorkflowRunStatusField(fields.Raw):
-    def output(self, key, obj: WorkflowRun, **kwargs):
-        return obj.status.value
-
-
-class WorkflowRunOutputsField(fields.Raw):
-    def output(self, key, obj: WorkflowRun, **kwargs):
-        if obj.status == WorkflowExecutionStatus.PAUSED:
-            return {}
-
-        outputs = obj.outputs_dict
-        return outputs or {}
-
-
 workflow_run_fields = {
     "id": fields.String,
     "workflow_id": fields.String,
-    "status": WorkflowRunStatusField,
+    "status": fields.String,
     "inputs": fields.Raw,
-    "outputs": WorkflowRunOutputsField,
+    "outputs": fields.Raw,
     "error": fields.String,
     "total_steps": fields.Integer,
     "total_tokens": fields.Integer,
"created_at": TimestampField, - "finished_at": OptionalTimestampField, + "finished_at": TimestampField, "elapsed_time": fields.Float, } diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index cfa39e0dfd..1d22954308 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -23,7 +23,6 @@ from . import ( feature, files, forgot_password, - human_input_form, login, message, passport, @@ -31,7 +30,6 @@ from . import ( saved_message, site, workflow, - workflow_events, ) api.add_namespace(web_ns) @@ -46,7 +44,6 @@ __all__ = [ "feature", "files", "forgot_password", - "human_input_form", "login", "message", "passport", @@ -55,5 +52,4 @@ __all__ = [ "site", "web_ns", "workflow", - "workflow_events", ] diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index d1f936768e..196a27e348 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -117,12 +117,6 @@ class InvokeRateLimitError(BaseHTTPException): code = 429 -class WebFormRateLimitExceededError(BaseHTTPException): - error_code = "web_form_rate_limit_exceeded" - description = "Too many form requests. Please try again later." - code = 429 - - class NotFoundError(BaseHTTPException): error_code = "not_found" code = 404 diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py deleted file mode 100644 index c3989b1965..0000000000 --- a/api/controllers/web/human_input_form.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Web App Human Input Form APIs. -""" - -import json -import logging -from datetime import datetime - -from flask import Response, request -from flask_restx import Resource, reqparse -from werkzeug.exceptions import Forbidden - -from configs import dify_config -from controllers.web import web_ns -from controllers.web.error import NotFoundError, WebFormRateLimitExceededError -from controllers.web.site import serialize_app_site_payload -from extensions.ext_database import db -from libs.helper import RateLimiter, extract_remote_ip -from models.account import TenantStatus -from models.model import App, Site -from services.human_input_service import Form, FormNotFoundError, HumanInputService - -logger = logging.getLogger(__name__) - -_FORM_SUBMIT_RATE_LIMITER = RateLimiter( - prefix="web_form_submit_rate_limit", - max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, - time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS, -) -_FORM_ACCESS_RATE_LIMITER = RateLimiter( - prefix="web_form_access_rate_limit", - max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, - time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS, -) - - -def _stringify_default_values(values: dict[str, object]) -> dict[str, str]: - result: dict[str, str] = {} - for key, value in values.items(): - if value is None: - result[key] = "" - elif isinstance(value, (dict, list)): - result[key] = json.dumps(value, ensure_ascii=False) - else: - result[key] = str(value) - return result - - -def _to_timestamp(value: datetime) -> int: - return int(value.timestamp()) - - -def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response: - """Return the form payload (optionally with site) as a JSON response.""" - definition_payload = form.get_definition().model_dump() - payload = { - "form_content": definition_payload["rendered_content"], - "inputs": definition_payload["inputs"], - "resolved_default_values": _stringify_default_values(definition_payload["default_values"]), - "user_actions": 
definition_payload["user_actions"], - "expiration_time": _to_timestamp(form.expiration_time), - } - if site_payload is not None: - payload["site"] = site_payload - return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json") - - -# TODO(QuantumGhost): disable authorization for web app -# form api temporarily - - -@web_ns.route("/form/human_input/") -# class HumanInputFormApi(WebApiResource): -class HumanInputFormApi(Resource): - """API for getting and submitting human input forms via the web app.""" - - # def get(self, _app_model: App, _end_user: EndUser, form_token: str): - def get(self, form_token: str): - """ - Get human input form definition by token. - - GET /api/form/human_input/ - """ - ip_address = extract_remote_ip(request) - if _FORM_ACCESS_RATE_LIMITER.is_rate_limited(ip_address): - raise WebFormRateLimitExceededError() - _FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address) - - service = HumanInputService(db.engine) - # TODO(QuantumGhost): forbid submision for form tokens - # that are only for console. - form = service.get_form_by_token(form_token) - - if form is None: - raise NotFoundError("Form not found") - - service.ensure_form_active(form) - app_model, site = _get_app_site_from_form(form) - - return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None)) - - # def post(self, _app_model: App, _end_user: EndUser, form_token: str): - def post(self, form_token: str): - """ - Submit human input form by token. - - POST /api/form/human_input/ - - Request body: - { - "inputs": { - "content": "User input content" - }, - "action": "Approve" - } - """ - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("action", type=str, required=True, location="json") - args = parser.parse_args() - - ip_address = extract_remote_ip(request) - if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address): - raise WebFormRateLimitExceededError() - _FORM_SUBMIT_RATE_LIMITER.increment_rate_limit(ip_address) - - service = HumanInputService(db.engine) - form = service.get_form_by_token(form_token) - if form is None: - raise NotFoundError("Form not found") - - if (recipient_type := form.recipient_type) is None: - logger.warning("Recipient type is None for form, form_id=%", form.id) - raise AssertionError("Recipient type is None") - - try: - service.submit_form_by_token( - recipient_type=recipient_type, - form_token=form_token, - selected_action_id=args["action"], - form_data=args["inputs"], - submission_end_user_id=None, - # submission_end_user_id=_end_user.id, - ) - except FormNotFoundError: - raise NotFoundError("Form not found") - - return {}, 200 - - -def _get_app_site_from_form(form: Form) -> tuple[App, Site]: - """Resolve App/Site for the form's app and validate tenant status.""" - app_model = db.session.query(App).where(App.id == form.app_id).first() - if app_model is None or app_model.tenant_id != form.tenant_id: - raise NotFoundError("Form not found") - - site = db.session.query(Site).where(Site.app_id == app_model.id).first() - if site is None: - raise Forbidden() - - if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE: - raise Forbidden() - - return app_model, site diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index f957229ece..b01aaba357 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,6 +1,4 @@ -from typing import cast - -from flask_restx import fields, marshal, marshal_with 
+from flask_restx import fields, marshal_with
 from werkzeug.exceptions import Forbidden
 
 from configs import dify_config
@@ -9,7 +7,7 @@ from controllers.web.wraps import WebApiResource
 from extensions.ext_database import db
 from libs.helper import AppIconUrlField
 from models.account import TenantStatus
-from models.model import App, Site
+from models.model import Site
 from services.feature_service import FeatureService
 
@@ -110,14 +108,3 @@ class AppSiteInfo:
             "remove_webapp_brand": remove_webapp_brand,
             "replace_webapp_logo": replace_webapp_logo,
         }
-
-
-def serialize_site(site: Site) -> dict:
-    """Serialize Site model using the same schema as AppSiteApi."""
-    return cast(dict, marshal(site, AppSiteApi.site_fields))
-
-
-def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
-    can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
-    app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
-    return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))

diff --git a/api/controllers/web/workflow_events.py b/api/controllers/web/workflow_events.py
deleted file mode 100644
index 61568e70e6..0000000000
--- a/api/controllers/web/workflow_events.py
+++ /dev/null
@@ -1,112 +0,0 @@
-"""
-Web App Workflow Resume APIs.
-"""
-
-import json
-from collections.abc import Generator
-
-from flask import Response, request
-from sqlalchemy.orm import sessionmaker
-
-from controllers.web import api
-from controllers.web.error import InvalidArgumentError, NotFoundError
-from controllers.web.wraps import WebApiResource
-from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
-from core.app.apps.base_app_generator import BaseAppGenerator
-from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
-from core.app.apps.message_generator import MessageGenerator
-from core.app.apps.workflow.app_generator import WorkflowAppGenerator
-from extensions.ext_database import db
-from models.enums import CreatorUserRole
-from models.model import App, AppMode, EndUser
-from repositories.factory import DifyAPIRepositoryFactory
-from services.workflow_event_snapshot_service import build_workflow_event_stream
-
-
-class WorkflowEventsApi(WebApiResource):
-    """API for getting workflow execution events after resume."""
-
-    def get(self, app_model: App, end_user: EndUser, task_id: str):
-        """
-        Get workflow execution events stream after resume.
-
-        GET /api/workflow/<task_id>/events
-
-        Returns Server-Sent Events stream.
- """ - workflow_run_id = task_id - session_maker = sessionmaker(db.engine) - repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) - workflow_run = repo.get_workflow_run_by_id_and_tenant_id( - tenant_id=app_model.tenant_id, - run_id=workflow_run_id, - ) - - if workflow_run is None: - raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}") - - if workflow_run.app_id != app_model.id: - raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}") - - if workflow_run.created_by_role != CreatorUserRole.END_USER: - raise NotFoundError(f"WorkflowRun not created by end user, id={workflow_run_id}") - - if workflow_run.created_by != end_user.id: - raise NotFoundError(f"WorkflowRun not created by the current end user, id={workflow_run_id}") - - if workflow_run.finished_at is not None: - response = WorkflowResponseConverter.workflow_run_result_to_finish_response( - task_id=workflow_run.id, - workflow_run=workflow_run, - creator_user=end_user, - ) - - payload = response.model_dump(mode="json") - payload["event"] = response.event.value - - def _generate_finished_events() -> Generator[str, None, None]: - yield f"data: {json.dumps(payload)}\n\n" - - event_generator = _generate_finished_events - else: - app_mode = AppMode.value_of(app_model.mode) - msg_generator = MessageGenerator() - generator: BaseAppGenerator - if app_mode == AppMode.ADVANCED_CHAT: - generator = AdvancedChatAppGenerator() - elif app_mode == AppMode.WORKFLOW: - generator = WorkflowAppGenerator() - else: - raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") - - include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" - - def _generate_stream_events(): - if include_state_snapshot: - return generator.convert_to_event_stream( - build_workflow_event_stream( - app_mode=app_mode, - workflow_run=workflow_run, - tenant_id=app_model.tenant_id, - app_id=app_model.id, - session_maker=session_maker, - ) - ) - return generator.convert_to_event_stream( - msg_generator.retrieve_events(app_mode, workflow_run.id), - ) - - event_generator = _generate_stream_events - - return Response( - event_generator(), - mimetype="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, - ) - - -# Register the APIs -api.add_resource(WorkflowEventsApi, "/workflow//events") diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 2891d3ceeb..528c45f6c8 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -4,8 +4,8 @@ import contextvars import logging import threading import uuid -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload +from collections.abc import Generator, Mapping +from typing import TYPE_CHECKING, Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -29,25 +29,21 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse -from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from 
core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory -from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom -from models.base import Base from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.workflow_draft_variable_service import ( @@ -69,9 +65,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user: Union[Account, EndUser], args: Mapping[str, Any], invoke_from: InvokeFrom, - workflow_run_id: str, streaming: Literal[False], - pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any]: ... @overload @@ -80,11 +74,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: Mapping[str, Any], + args: Mapping, invoke_from: InvokeFrom, - workflow_run_id: str, streaming: Literal[True], - pause_state_config: PauseStateLayerConfig | None = None, ) -> Generator[Mapping | str, None, None]: ... @overload @@ -93,11 +85,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: Mapping[str, Any], + args: Mapping, invoke_from: InvokeFrom, - workflow_run_id: str, streaming: bool, - pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ... def generate( @@ -105,11 +95,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: Mapping[str, Any], + args: Mapping, invoke_from: InvokeFrom, - workflow_run_id: str, streaming: bool = True, - pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: """ Generate App response. 
@@ -173,6 +161,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         # always enable retriever resource in debugger mode
         app_config.additional_features.show_retrieve_source = True  # type: ignore
 
+        workflow_run_id = str(uuid.uuid4())
         # init application generate entity
         application_generate_entity = AdvancedChatAppGenerateEntity(
             task_id=str(uuid.uuid4()),
@@ -190,7 +179,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=invoke_from,
             extras=extras,
             trace_manager=trace_manager,
-            workflow_run_id=str(workflow_run_id),
+            workflow_run_id=workflow_run_id,
         )
         contexts.plugin_tool_providers.set({})
         contexts.plugin_tool_providers_lock.set(threading.Lock())
@@ -227,38 +216,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             workflow_node_execution_repository=workflow_node_execution_repository,
             conversation=conversation,
             stream=streaming,
-            pause_state_config=pause_state_config,
-        )
-
-    def resume(
-        self,
-        *,
-        app_model: App,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        conversation: Conversation,
-        message: Message,
-        application_generate_entity: AdvancedChatAppGenerateEntity,
-        workflow_execution_repository: WorkflowExecutionRepository,
-        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
-        graph_runtime_state: GraphRuntimeState,
-        pause_state_config: PauseStateLayerConfig | None = None,
-    ):
-        """
-        Resume a paused advanced chat execution.
-        """
-        return self._generate(
-            workflow=workflow,
-            user=user,
-            invoke_from=application_generate_entity.invoke_from,
-            application_generate_entity=application_generate_entity,
-            workflow_execution_repository=workflow_execution_repository,
-            workflow_node_execution_repository=workflow_node_execution_repository,
-            conversation=conversation,
-            message=message,
-            stream=application_generate_entity.stream,
-            pause_state_config=pause_state_config,
-            graph_runtime_state=graph_runtime_state,
         )
 
     def single_iteration_generate(
@@ -439,12 +396,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
         conversation: Conversation | None = None,
-        message: Message | None = None,
         stream: bool = True,
         variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
-        pause_state_config: PauseStateLayerConfig | None = None,
-        graph_runtime_state: GraphRuntimeState | None = None,
-        graph_engine_layers: Sequence[GraphEngineLayer] = (),
     ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
         """
         Generate App response.
@@ -458,12 +411,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         :param conversation: conversation
         :param stream: is stream
         """
-        is_first_conversation = conversation is None
+        is_first_conversation = False
+        if not conversation:
+            is_first_conversation = True
 
-        if conversation is not None and message is not None:
-            pass
-        else:
-            conversation, message = self._init_generate_records(application_generate_entity, conversation)
+        # init generate records
+        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)
 
         if is_first_conversation:
             # update conversation features
@@ -486,16 +439,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             message_id=message.id,
         )
 
-        graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
-        if pause_state_config is not None:
-            graph_layers.append(
-                PauseStatePersistenceLayer(
-                    session_factory=pause_state_config.session_factory,
-                    generate_entity=application_generate_entity,
-                    state_owner_user_id=pause_state_config.state_owner_user_id,
-                )
-            )
-
         # new thread with request context and contextvars
         context = contextvars.copy_context()
 
@@ -511,25 +454,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                 "variable_loader": variable_loader,
                 "workflow_execution_repository": workflow_execution_repository,
                 "workflow_node_execution_repository": workflow_node_execution_repository,
-                "graph_engine_layers": tuple(graph_layers),
-                "graph_runtime_state": graph_runtime_state,
             },
         )
 
         worker_thread.start()
 
         # release database connection, because the following new thread operations may take a long time
-        with Session(bind=db.engine, expire_on_commit=False) as session:
-            workflow = _refresh_model(session, workflow)
-            message = _refresh_model(session, message)
-        # workflow_ = session.get(Workflow, workflow.id)
-        # assert workflow_ is not None
-        # workflow = workflow_
-        # message_ = session.get(Message, message.id)
-        # assert message_ is not None
-        # message = message_
-        # db.session.refresh(workflow)
-        # db.session.refresh(message)
+        db.session.refresh(workflow)
+        db.session.refresh(message)
         # db.session.refresh(user)
         db.session.close()
 
@@ -558,8 +490,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         variable_loader: VariableLoader,
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
-        graph_engine_layers: Sequence[GraphEngineLayer] = (),
-        graph_runtime_state: GraphRuntimeState | None = None,
     ):
         """
         Generate worker in a new thread.
@@ -617,8 +547,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                 app=app,
                 workflow_execution_repository=workflow_execution_repository,
                 workflow_node_execution_repository=workflow_node_execution_repository,
-                graph_engine_layers=graph_engine_layers,
-                graph_runtime_state=graph_runtime_state,
             )
 
             try:
@@ -686,13 +614,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             else:
                 logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
                 raise e
-
-
-_T = TypeVar("_T", bound=Base)
-
-
-def _refresh_model(session, model: _T) -> _T:
-    with Session(bind=db.engine, expire_on_commit=False) as session:
-        detach_model = session.get(type(model), model.id)
-        assert detach_model is not None
-        return detach_model

diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index 8b20442eab..d702db0908 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -66,7 +66,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
         graph_engine_layers: Sequence[GraphEngineLayer] = (),
-        graph_runtime_state: GraphRuntimeState | None = None,
     ):
         super().__init__(
             queue_manager=queue_manager,
@@ -83,7 +82,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         self._app = app
         self._workflow_execution_repository = workflow_execution_repository
         self._workflow_node_execution_repository = workflow_node_execution_repository
-        self._resume_graph_runtime_state = graph_runtime_state
 
     @trace_span(WorkflowAppRunnerHandler)
     def run(self):
@@ -112,21 +110,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             invoke_from = InvokeFrom.DEBUGGER
 
         user_from = self._resolve_user_from(invoke_from)
-        resume_state = self._resume_graph_runtime_state
-
-        if resume_state is not None:
-            graph_runtime_state = resume_state
-            variable_pool = graph_runtime_state.variable_pool
-            graph = self._init_graph(
-                graph_config=self._workflow.graph_dict,
-                graph_runtime_state=graph_runtime_state,
-                workflow_id=self._workflow.id,
-                tenant_id=self._workflow.tenant_id,
-                user_id=self.application_generate_entity.user_id,
-                invoke_from=invoke_from,
-                user_from=user_from,
-            )
-        elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+        if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
             # Handle single iteration or single loop run
             graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
                 workflow=self._workflow,

diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 00a6a3d9af..da1e9f19b6 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -24,8 +24,6 @@ from core.app.entities.queue_entities import (
     QueueAgentLogEvent,
     QueueAnnotationReplyEvent,
     QueueErrorEvent,
-    QueueHumanInputFormFilledEvent,
-    QueueHumanInputFormTimeoutEvent,
     QueueIterationCompletedEvent,
     QueueIterationNextEvent,
     QueueIterationStartEvent,
@@ -44,7 +42,6 @@ from core.app.entities.queue_entities import (
     QueueTextChunkEvent,
     QueueWorkflowFailedEvent,
     QueueWorkflowPartialSuccessEvent,
-    QueueWorkflowPausedEvent,
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
     WorkflowQueueMessage,
@@ -66,8 +63,6 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
-from core.workflow.entities.pause_reason import HumanInputRequired
 from core.workflow.enums import WorkflowExecutionStatus
 from core.workflow.nodes import NodeType
 from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@@ -76,8 +71,7 @@ from core.workflow.system_variable import SystemVariable
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
 from models import Account, Conversation, EndUser, Message, MessageFile
-from models.enums import CreatorUserRole, MessageStatus
-from models.execution_extra_content import HumanInputContent
+from models.enums import CreatorUserRole
 from models.workflow import Workflow
 
 logger = logging.getLogger(__name__)
@@ -134,7 +128,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         )
 
         self._task_state = WorkflowTaskState()
-        self._seed_task_state_from_message(message)
         self._message_cycle_manager = MessageCycleManager(
             application_generate_entity=application_generate_entity, task_state=self._task_state
         )
@@ -142,7 +135,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         self._application_generate_entity = application_generate_entity
         self._workflow_id = workflow.id
         self._workflow_features_dict = workflow.features_dict
-        self._workflow_tenant_id = workflow.tenant_id
         self._conversation_id = conversation.id
         self._conversation_mode = conversation.mode
         self._message_id = message.id
@@ -152,13 +144,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         self._workflow_run_id: str = ""
         self._draft_var_saver_factory = draft_var_saver_factory
         self._graph_runtime_state: GraphRuntimeState | None = None
-        self._message_saved_on_pause = False
         self._seed_graph_runtime_state_from_queue_manager()
 
-    def _seed_task_state_from_message(self, message: Message) -> None:
-        if message.status == MessageStatus.PAUSED and message.answer:
-            self._task_state.answer = message.answer
-
     def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
         """
         Process generate task pipeline.
@@ -321,7 +308,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): task_id=self._application_generate_entity.task_id, workflow_run_id=run_id, workflow_id=self._workflow_id, - reason=event.reason, ) yield workflow_start_resp @@ -539,35 +525,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) yield workflow_finish_resp - - def _handle_workflow_paused_event( - self, - event: QueueWorkflowPausedEvent, - **kwargs, - ) -> Generator[StreamResponse, None, None]: - """Handle workflow paused events.""" - validated_state = self._ensure_graph_runtime_initialized() - responses = self._workflow_response_converter.workflow_pause_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - graph_runtime_state=validated_state, - ) - for reason in event.reasons: - if isinstance(reason, HumanInputRequired): - self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id) - yield from responses - resolved_state: GraphRuntimeState | None = None - try: - resolved_state = self._ensure_graph_runtime_initialized() - except ValueError: - resolved_state = None - - with self._database_session() as session: - self._save_message(session=session, graph_runtime_state=resolved_state) - message = self._get_message(session=session) - if message is not None: - message.status = MessageStatus.PAUSED - self._message_saved_on_pause = True self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) def _handle_workflow_failed_event( @@ -657,10 +614,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, ) - # Save message unless it has already been persisted on pause. 
- if not self._message_saved_on_pause: - with self._database_session() as session: - self._save_message(session=session, graph_runtime_state=resolved_state) + # Save message + with self._database_session() as session: + self._save_message(session=session, graph_runtime_state=resolved_state) yield self._message_end_to_stream_response() @@ -686,65 +642,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """Handle message replace events.""" yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason) - def _handle_human_input_form_filled_event( - self, event: QueueHumanInputFormFilledEvent, **kwargs - ) -> Generator[StreamResponse, None, None]: - """Handle human input form filled events.""" - self._persist_human_input_extra_content(node_id=event.node_id) - yield self._workflow_response_converter.human_input_form_filled_to_stream_response( - event=event, task_id=self._application_generate_entity.task_id - ) - - def _handle_human_input_form_timeout_event( - self, event: QueueHumanInputFormTimeoutEvent, **kwargs - ) -> Generator[StreamResponse, None, None]: - """Handle human input form timeout events.""" - yield self._workflow_response_converter.human_input_form_timeout_to_stream_response( - event=event, task_id=self._application_generate_entity.task_id - ) - - def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None: - if not self._workflow_run_id or not self._message_id: - return - - if form_id is None: - if node_id is None: - return - form_id = self._load_human_input_form_id(node_id=node_id) - if form_id is None: - logger.warning( - "HumanInput form not found for workflow run %s node %s", - self._workflow_run_id, - node_id, - ) - return - - with self._database_session() as session: - exists_stmt = select(HumanInputContent).where( - HumanInputContent.workflow_run_id == self._workflow_run_id, - HumanInputContent.message_id == self._message_id, - HumanInputContent.form_id == form_id, - ) - if session.scalar(exists_stmt) is not None: - return - - content = HumanInputContent( - workflow_run_id=self._workflow_run_id, - message_id=self._message_id, - form_id=form_id, - ) - session.add(content) - - def _load_human_input_form_id(self, *, node_id: str) -> str | None: - form_repository = HumanInputFormRepositoryImpl( - session_factory=db.engine, - tenant_id=self._workflow_tenant_id, - ) - form = form_repository.get_form(self._workflow_run_id, node_id) - if form is None: - return None - return form.id - def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: """Handle agent log events.""" yield self._workflow_response_converter.handle_agent_log( @@ -762,7 +659,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueueWorkflowStartedEvent: self._handle_workflow_started_event, QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event, - QueueWorkflowPausedEvent: self._handle_workflow_paused_event, QueueWorkflowFailedEvent: self._handle_workflow_failed_event, # Node events QueueNodeRetryEvent: self._handle_node_retry_event, @@ -784,8 +680,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueueMessageReplaceEvent: self._handle_message_replace_event, QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event, QueueAgentLogEvent: self._handle_agent_log_event, - 
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event, - QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event, } def _dispatch_event( @@ -853,9 +747,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): case QueueWorkflowFailedEvent(): yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager) break - case QueueWorkflowPausedEvent(): - yield from self._handle_workflow_paused_event(event) - break case QueueStopEvent(): yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager) @@ -881,11 +772,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None): message = self._get_message(session=session) - if message is None: - return - - if message.status == MessageStatus.PAUSED: - message.status = MessageStatus.NORMAL # If there are assistant files, remove markdown image links from answer answer_text = self._task_state.answer diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 6d329063f8..38ecec5d30 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -5,14 +5,9 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, NewType, Union -from sqlalchemy import select -from sqlalchemy.orm import Session - from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( QueueAgentLogEvent, - QueueHumanInputFormFilledEvent, - QueueHumanInputFormTimeoutEvent, QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, @@ -24,13 +19,9 @@ from core.app.entities.queue_entities import ( QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, - QueueWorkflowPausedEvent, ) from core.app.entities.task_entities import ( AgentLogStreamResponse, - HumanInputFormFilledResponse, - HumanInputFormTimeoutResponse, - HumanInputRequiredResponse, IterationNodeCompletedStreamResponse, IterationNodeNextStreamResponse, IterationNodeStartStreamResponse, @@ -40,9 +31,7 @@ from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeRetryStreamResponse, NodeStartStreamResponse, - StreamResponse, WorkflowFinishStreamResponse, - WorkflowPauseStreamResponse, WorkflowStartStreamResponse, ) from core.file import FILE_MODEL_IDENTITY, File @@ -51,8 +40,6 @@ from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.trigger.trigger_manager import TriggerManager from core.variables.segments import ArrayFileSegment, FileSegment, Segment -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import ( NodeType, SystemVariableKey, @@ -64,11 +51,8 @@ from core.workflow.runtime import GraphRuntimeState from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, EndUser -from models.human_input import HumanInputForm -from models.workflow import WorkflowRun 
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator NodeExecutionId = NewType("NodeExecutionId", str) @@ -207,7 +191,6 @@ class WorkflowResponseConverter: task_id: str, workflow_run_id: str, workflow_id: str, - reason: WorkflowStartReason, ) -> WorkflowStartStreamResponse: run_id = self._ensure_workflow_run_id(workflow_run_id) started_at = naive_utc_now() @@ -221,7 +204,6 @@ class WorkflowResponseConverter: workflow_id=workflow_id, inputs=self._workflow_inputs, created_at=int(started_at.timestamp()), - reason=reason, ), ) @@ -282,160 +264,6 @@ class WorkflowResponseConverter: ), ) - def workflow_pause_to_stream_response( - self, - *, - event: QueueWorkflowPausedEvent, - task_id: str, - graph_runtime_state: GraphRuntimeState, - ) -> list[StreamResponse]: - run_id = self._ensure_workflow_run_id() - started_at = self._workflow_started_at - if started_at is None: - raise ValueError( - "workflow_pause_to_stream_response called before workflow_start_to_stream_response", - ) - paused_at = naive_utc_now() - elapsed_time = (paused_at - started_at).total_seconds() - encoded_outputs = self._encode_outputs(event.outputs) or {} - if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API: - encoded_outputs = {} - pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons] - human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)] - expiration_times_by_form_id: dict[str, datetime] = {} - if human_input_form_ids: - stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( - HumanInputForm.id.in_(human_input_form_ids) - ) - with Session(bind=db.engine) as session: - for form_id, expiration_time in session.execute(stmt): - expiration_times_by_form_id[str(form_id)] = expiration_time - - responses: list[StreamResponse] = [] - - for reason in event.reasons: - if isinstance(reason, HumanInputRequired): - expiration_time = expiration_times_by_form_id.get(reason.form_id) - if expiration_time is None: - raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}") - responses.append( - HumanInputRequiredResponse( - task_id=task_id, - workflow_run_id=run_id, - data=HumanInputRequiredResponse.Data( - form_id=reason.form_id, - node_id=reason.node_id, - node_title=reason.node_title, - form_content=reason.form_content, - inputs=reason.inputs, - actions=reason.actions, - display_in_ui=reason.display_in_ui, - form_token=reason.form_token, - resolved_default_values=reason.resolved_default_values, - expiration_time=int(expiration_time.timestamp()), - ), - ) - ) - - responses.append( - WorkflowPauseStreamResponse( - task_id=task_id, - workflow_run_id=run_id, - data=WorkflowPauseStreamResponse.Data( - workflow_run_id=run_id, - paused_nodes=list(event.paused_nodes), - outputs=encoded_outputs, - reasons=pause_reasons, - status=WorkflowExecutionStatus.PAUSED.value, - created_at=int(started_at.timestamp()), - elapsed_time=elapsed_time, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - ), - ) - ) - - return responses - - def human_input_form_filled_to_stream_response( - self, *, event: QueueHumanInputFormFilledEvent, task_id: str - ) -> HumanInputFormFilledResponse: - run_id = self._ensure_workflow_run_id() - return HumanInputFormFilledResponse( - task_id=task_id, - workflow_run_id=run_id, - data=HumanInputFormFilledResponse.Data( - node_id=event.node_id, - node_title=event.node_title, - 
rendered_content=event.rendered_content, - action_id=event.action_id, - action_text=event.action_text, - ), - ) - - def human_input_form_timeout_to_stream_response( - self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str - ) -> HumanInputFormTimeoutResponse: - run_id = self._ensure_workflow_run_id() - return HumanInputFormTimeoutResponse( - task_id=task_id, - workflow_run_id=run_id, - data=HumanInputFormTimeoutResponse.Data( - node_id=event.node_id, - node_title=event.node_title, - expiration_time=int(event.expiration_time.timestamp()), - ), - ) - - @classmethod - def workflow_run_result_to_finish_response( - cls, - *, - task_id: str, - workflow_run: WorkflowRun, - creator_user: Account | EndUser, - ) -> WorkflowFinishStreamResponse: - run_id = workflow_run.id - elapsed_time = workflow_run.elapsed_time - - encoded_outputs = workflow_run.outputs_dict - finished_at = workflow_run.finished_at - assert finished_at is not None - - created_by: Mapping[str, object] - user = creator_user - if isinstance(user, Account): - created_by = { - "id": user.id, - "name": user.name, - "email": user.email, - } - else: - created_by = { - "id": user.id, - "user": user.session_id, - } - - return WorkflowFinishStreamResponse( - task_id=task_id, - workflow_run_id=run_id, - data=WorkflowFinishStreamResponse.Data( - id=run_id, - workflow_id=workflow_run.workflow_id, - status=workflow_run.status.value, - outputs=encoded_outputs, - error=workflow_run.error, - elapsed_time=elapsed_time, - total_tokens=workflow_run.total_tokens, - total_steps=workflow_run.total_steps, - created_by=created_by, - created_at=int(workflow_run.created_at.timestamp()), - finished_at=int(finished_at.timestamp()), - files=cls.fetch_files_from_node_outputs(encoded_outputs), - exceptions_count=workflow_run.exceptions_count, - ), - ) - def workflow_node_start_to_stream_response( self, *, @@ -764,8 +592,7 @@ class WorkflowResponseConverter: ), ) - @classmethod - def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]: + def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -774,7 +601,7 @@ class WorkflowResponseConverter: if not outputs_dict: return [] - files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] + files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] # Remove None files = [file for file in files if file] # Flatten list diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 4e9a191dae..57617d8863 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,6 +1,6 @@ import json import logging -from collections.abc import Callable, Generator, Mapping +from collections.abc import Generator from typing import Union, cast from sqlalchemy import select @@ -10,14 +10,12 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.exc import GenerateTaskStoppedError -from core.app.apps.streaming_utils import stream_topic_events from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, AppGenerateEntity, 
ChatAppGenerateEntity, CompletionAppGenerateEntity, - ConversationAppGenerateEntity, InvokeFrom, ) from core.app.entities.task_entities import ( @@ -29,8 +27,6 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db -from extensions.ext_redis import get_pubsub_broadcast_channel -from libs.broadcast_channel.channel import Topic from libs.datetime_utils import naive_utc_now from models import Account from models.enums import CreatorUserRole @@ -160,7 +156,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): query = application_generate_entity.query or "New conversation" conversation_name = (query[:20] + "…") if len(query) > 20 else query - created_new_conversation = conversation is None try: if not conversation: conversation = Conversation( @@ -237,10 +232,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add_all(message_files) db.session.commit() - - if isinstance(application_generate_entity, ConversationAppGenerateEntity): - application_generate_entity.conversation_id = conversation.id - application_generate_entity.is_new_conversation = created_new_conversation return conversation, message except Exception: db.session.rollback() @@ -293,29 +284,3 @@ class MessageBasedAppGenerator(BaseAppGenerator): raise MessageNotExistsError("Message not exists") return message - - @staticmethod - def _make_channel_key(app_mode: AppMode, workflow_run_id: str): - return f"channel:{app_mode}:{workflow_run_id}" - - @classmethod - def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic: - key = cls._make_channel_key(app_mode, workflow_run_id) - channel = get_pubsub_broadcast_channel() - topic = channel.topic(key) - return topic - - @classmethod - def retrieve_events( - cls, - app_mode: AppMode, - workflow_run_id: str, - idle_timeout=300, - on_subscribe: Callable[[], None] | None = None, - ) -> Generator[Mapping | str, None, None]: - topic = cls.get_response_topic(app_mode, workflow_run_id) - return stream_topic_events( - topic=topic, - idle_timeout=idle_timeout, - on_subscribe=on_subscribe, - ) diff --git a/api/core/app/apps/message_generator.py b/api/core/app/apps/message_generator.py deleted file mode 100644 index 68631bb230..0000000000 --- a/api/core/app/apps/message_generator.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Callable, Generator, Mapping - -from core.app.apps.streaming_utils import stream_topic_events -from extensions.ext_redis import get_pubsub_broadcast_channel -from libs.broadcast_channel.channel import Topic -from models.model import AppMode - - -class MessageGenerator: - @staticmethod - def _make_channel_key(app_mode: AppMode, workflow_run_id: str): - return f"channel:{app_mode}:{str(workflow_run_id)}" - - @classmethod - def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic: - key = cls._make_channel_key(app_mode, workflow_run_id) - channel = get_pubsub_broadcast_channel() - topic = channel.topic(key) - return topic - - @classmethod - def retrieve_events( - cls, - app_mode: AppMode, - workflow_run_id: str, - idle_timeout=300, - ping_interval: float = 10.0, - on_subscribe: Callable[[], None] | None = None, - ) -> Generator[Mapping | str, None, None]: - topic = cls.get_response_topic(app_mode, workflow_run_id) - return stream_topic_events( - topic=topic, - idle_timeout=idle_timeout, - 
ping_interval=ping_interval, - on_subscribe=on_subscribe, - ) diff --git a/api/core/app/apps/streaming_utils.py b/api/core/app/apps/streaming_utils.py deleted file mode 100644 index 57d4b537a4..0000000000 --- a/api/core/app/apps/streaming_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -import json -import time -from collections.abc import Callable, Generator, Iterable, Mapping -from typing import Any - -from core.app.entities.task_entities import StreamEvent -from libs.broadcast_channel.channel import Topic -from libs.broadcast_channel.exc import SubscriptionClosedError - - -def stream_topic_events( - *, - topic: Topic, - idle_timeout: float, - ping_interval: float | None = None, - on_subscribe: Callable[[], None] | None = None, - terminal_events: Iterable[str | StreamEvent] | None = None, -) -> Generator[Mapping[str, Any] | str, None, None]: - # send a PING event immediately to prevent the connection staying in pending state for a long time. - # - # This simplify the debugging process as the DevTools in Chrome does not - # provide complete curl command for pending connections. - yield StreamEvent.PING.value - - terminal_values = _normalize_terminal_events(terminal_events) - last_msg_time = time.time() - last_ping_time = last_msg_time - with topic.subscribe() as sub: - # on_subscribe fires only after the Redis subscription is active. - # This is used to gate task start and reduce pub/sub race for the first event. - if on_subscribe is not None: - on_subscribe() - while True: - try: - msg = sub.receive(timeout=0.1) - except SubscriptionClosedError: - return - if msg is None: - current_time = time.time() - if current_time - last_msg_time > idle_timeout: - return - if ping_interval is not None and current_time - last_ping_time >= ping_interval: - yield StreamEvent.PING.value - last_ping_time = current_time - continue - - last_msg_time = time.time() - last_ping_time = last_msg_time - event = json.loads(msg) - yield event - if not isinstance(event, dict): - continue - - event_type = event.get("event") - if event_type in terminal_values: - return - - -def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]: - if not terminal_events: - return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value} - values: set[str] = set() - for item in terminal_events: - if isinstance(item, StreamEvent): - values.add(item.value) - else: - values.add(str(item)) - return values diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index dc5852d552..ee205ed153 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -25,7 +25,6 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse -from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.model_runtime.errors.invoke import InvokeAuthorizationError @@ -35,15 +34,12 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.repositories.draft_variable_repository import 
DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts -from models.account import Account +from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom -from models.model import App, EndUser -from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService if TYPE_CHECKING: @@ -70,11 +66,9 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[True], call_depth: int, - workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - pause_state_config: PauseStateLayerConfig | None = None, ) -> Generator[Mapping[str, Any] | str, None, None]: ... @overload @@ -88,11 +82,9 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[False], call_depth: int, - workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - pause_state_config: PauseStateLayerConfig | None = None, ) -> Mapping[str, Any]: ... @overload @@ -106,11 +98,9 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool, call_depth: int, - workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ... 
def generate( @@ -123,11 +113,9 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool = True, call_depth: int = 0, - workflow_run_id: str | uuid.UUID | None = None, triggered_from: WorkflowRunTriggeredFrom | None = None, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: files: Sequence[Mapping[str, Any]] = args.get("files") or [] @@ -162,7 +150,7 @@ class WorkflowAppGenerator(BaseAppGenerator): extras = { **extract_external_trace_id_from_args(args), } - workflow_run_id = str(workflow_run_id or uuid.uuid4()) + workflow_run_id = str(uuid.uuid4()) # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args # trigger shouldn't prepare user inputs if self._should_prepare_user_inputs(args): @@ -228,40 +216,13 @@ class WorkflowAppGenerator(BaseAppGenerator): streaming=streaming, root_node_id=root_node_id, graph_engine_layers=graph_engine_layers, - pause_state_config=pause_state_config, ) - def resume( - self, - *, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - application_generate_entity: WorkflowAppGenerateEntity, - graph_runtime_state: GraphRuntimeState, - workflow_execution_repository: WorkflowExecutionRepository, - workflow_node_execution_repository: WorkflowNodeExecutionRepository, - graph_engine_layers: Sequence[GraphEngineLayer] = (), - pause_state_config: PauseStateLayerConfig | None = None, - variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, - ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: + def resume(self, *, workflow_run_id: str) -> None: """ - Resume a paused workflow execution using the persisted runtime state. + Resuming a paused workflow run is not yet implemented. """ - return self._generate( - app_model=app_model, - workflow=workflow, - user=user, - application_generate_entity=application_generate_entity, - invoke_from=application_generate_entity.invoke_from, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=application_generate_entity.stream, - variable_loader=variable_loader, - graph_engine_layers=graph_engine_layers, - graph_runtime_state=graph_runtime_state, - pause_state_config=pause_state_config, - ) + pass def _generate( self, @@ -277,8 +238,6 @@ class WorkflowAppGenerator(BaseAppGenerator): variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - graph_runtime_state: GraphRuntimeState | None = None, - pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ Generate App response. 
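The retained `@overload` stack above is what lets `generate` return either a mapping or a generator depending on the `streaming` flag, without casts at the call sites. A stripped-down sketch of that typing pattern; the class name and parameter list are trimmed for illustration:

```python
# Sketch of the Literal-overload pattern used by WorkflowAppGenerator.generate;
# the parameter list is reduced to the streaming flag for illustration.
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload


class AppGenerator:
    @overload
    def generate(self, *, streaming: Literal[True] = True) -> Generator[Mapping[str, Any] | str, None, None]: ...

    @overload
    def generate(self, *, streaming: Literal[False]) -> Mapping[str, Any]: ...

    def generate(
        self, *, streaming: bool = True
    ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
        if streaming:
            # Callers passing streaming=True are typed as receiving a generator.
            return (chunk for chunk in ("hello", "world"))
        # Callers passing streaming=False get the final mapping directly.
        return {"answer": "hello world"}


blocking: Mapping[str, Any] = AppGenerator().generate(streaming=False)
```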
@@ -292,8 +251,6 @@ class WorkflowAppGenerator(BaseAppGenerator): :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream """ - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - # init queue manager queue_manager = WorkflowAppQueueManager( task_id=application_generate_entity.task_id, @@ -302,15 +259,6 @@ class WorkflowAppGenerator(BaseAppGenerator): app_mode=app_model.mode, ) - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) - ) - # new thread with request context and contextvars context = contextvars.copy_context() @@ -328,8 +276,7 @@ class WorkflowAppGenerator(BaseAppGenerator): "root_node_id": root_node_id, "workflow_execution_repository": workflow_execution_repository, "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, + "graph_engine_layers": graph_engine_layers, }, ) @@ -431,7 +378,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, - pause_state_config=None, ) def single_loop_generate( @@ -513,7 +459,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, - pause_state_config=None, ) def _generate_worker( @@ -527,7 +472,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, root_node_id: str | None = None, graph_engine_layers: Sequence[GraphEngineLayer] = (), - graph_runtime_state: GraphRuntimeState | None = None, ) -> None: """ Generate worker in a new thread. 
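As the `# new thread with request context and contextvars` hunk above shows, generation runs on a worker thread inside a copied `contextvars` context (the real code also wraps the worker with `preserve_flask_contexts`). A minimal standalone sketch of that hand-off, with illustrative variable names:

```python
# Standalone sketch of running a generate worker on a new thread while
# carrying over the caller's contextvars; names are illustrative, and the
# real code additionally preserves Flask app/request contexts.
import contextvars
import threading

request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")


def _generate_worker() -> None:
    # Runs inside the copied context, so it sees the spawning thread's value.
    print("worker sees request:", request_id.get())


def start_generation() -> threading.Thread:
    request_id.set("req-123")
    context = contextvars.copy_context()
    worker = threading.Thread(target=lambda: context.run(_generate_worker))
    worker.start()
    return worker


start_generation().join()
```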
@@ -573,7 +517,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, root_node_id=root_node_id, graph_engine_layers=graph_engine_layers, - graph_runtime_state=graph_runtime_state, ) try: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index a43f7879d6..0ee3c177f2 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -42,7 +42,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, graph_engine_layers: Sequence[GraphEngineLayer] = (), - graph_runtime_state: GraphRuntimeState | None = None, ): super().__init__( queue_manager=queue_manager, @@ -56,7 +55,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self._root_node_id = root_node_id self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository - self._resume_graph_runtime_state = graph_runtime_state @trace_span(WorkflowAppRunnerHandler) def run(self): @@ -65,28 +63,23 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): """ app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) + + system_inputs = SystemVariable( + files=self.application_generate_entity.files, + user_id=self._sys_user_id, + app_id=app_config.app_id, + timestamp=int(naive_utc_now().timestamp()), + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_execution_id, + ) + invoke_from = self.application_generate_entity.invoke_from # if only single iteration or single loop run is requested if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: invoke_from = InvokeFrom.DEBUGGER user_from = self._resolve_user_from(invoke_from) - resume_state = self._resume_graph_runtime_state - - if resume_state is not None: - graph_runtime_state = resume_state - variable_pool = graph_runtime_state.variable_pool - graph = self._init_graph( - graph_config=self._workflow.graph_dict, - graph_runtime_state=graph_runtime_state, - workflow_id=self._workflow.id, - tenant_id=self._workflow.tenant_id, - user_id=self.application_generate_entity.user_id, - user_from=user_from, - invoke_from=invoke_from, - root_node_id=self._root_node_id, - ) - elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, @@ -96,14 +89,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): inputs = self.application_generate_entity.inputs # Create a variable pool. 
- system_inputs = SystemVariable( - files=self.application_generate_entity.files, - user_id=self._sys_user_id, - app_id=app_config.app_id, - timestamp=int(naive_utc_now().timestamp()), - workflow_id=app_config.workflow_id, - workflow_execution_id=self.application_generate_entity.workflow_execution_id, - ) + variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, @@ -112,6 +98,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + # init graph graph = self._init_graph( graph_config=self._workflow.graph_dict, graph_runtime_state=graph_runtime_state, diff --git a/api/core/app/apps/workflow/errors.py b/api/core/app/apps/workflow/errors.py deleted file mode 100644 index 16cd864209..0000000000 --- a/api/core/app/apps/workflow/errors.py +++ /dev/null @@ -1,7 +0,0 @@ -from libs.exception import BaseHTTPException - - -class WorkflowPausedInBlockingModeError(BaseHTTPException): - error_code = "workflow_paused_in_blocking_mode" - description = "Workflow execution paused for human input; blocking response mode is not supported." - code = 400 diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 0a567a4315..842ad545ad 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -16,8 +16,6 @@ from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentLogEvent, QueueErrorEvent, - QueueHumanInputFormFilledEvent, - QueueHumanInputFormTimeoutEvent, QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, @@ -34,7 +32,6 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, QueueWorkflowFailedEvent, QueueWorkflowPartialSuccessEvent, - QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, WorkflowQueueMessage, @@ -49,13 +46,11 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, - WorkflowPauseStreamResponse, WorkflowStartStreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import WorkflowExecutionStatus from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.runtime import GraphRuntimeState @@ -137,25 +132,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): for stream_response in generator: if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err - elif isinstance(stream_response, WorkflowPauseStreamResponse): - response = WorkflowAppBlockingResponse( - task_id=self._application_generate_entity.task_id, - workflow_run_id=stream_response.data.workflow_run_id, - data=WorkflowAppBlockingResponse.Data( - id=stream_response.data.workflow_run_id, - workflow_id=self._workflow.id, - status=stream_response.data.status, - outputs=stream_response.data.outputs or {}, - error=None, - elapsed_time=stream_response.data.elapsed_time, - total_tokens=stream_response.data.total_tokens, - total_steps=stream_response.data.total_steps, - created_at=stream_response.data.created_at, - finished_at=None, - ), - ) - - return response elif 
isinstance(stream_response, WorkflowFinishStreamResponse): response = WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, @@ -170,7 +146,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): total_tokens=stream_response.data.total_tokens, total_steps=stream_response.data.total_steps, created_at=int(stream_response.data.created_at), - finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None, + finished_at=int(stream_response.data.finished_at), ), ) @@ -283,15 +259,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): run_id = self._extract_workflow_run_id(runtime_state) self._workflow_execution_id = run_id - if event.reason == WorkflowStartReason.INITIAL: - with self._database_session() as session: - self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id) + with self._database_session() as session: + self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id) start_resp = self._workflow_response_converter.workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run_id=run_id, workflow_id=self._workflow.id, - reason=event.reason, ) yield start_resp @@ -466,21 +440,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) yield workflow_finish_resp - def _handle_workflow_paused_event( - self, - event: QueueWorkflowPausedEvent, - **kwargs, - ) -> Generator[StreamResponse, None, None]: - """Handle workflow paused events.""" - self._ensure_workflow_initialized() - validated_state = self._ensure_graph_runtime_initialized() - responses = self._workflow_response_converter.workflow_pause_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - graph_runtime_state=validated_state, - ) - yield from responses - def _handle_workflow_failed_and_stop_events( self, event: Union[QueueWorkflowFailedEvent, QueueStopEvent], @@ -536,22 +495,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): task_id=self._application_generate_entity.task_id, event=event ) - def _handle_human_input_form_filled_event( - self, event: QueueHumanInputFormFilledEvent, **kwargs - ) -> Generator[StreamResponse, None, None]: - """Handle human input form filled events.""" - yield self._workflow_response_converter.human_input_form_filled_to_stream_response( - event=event, task_id=self._application_generate_entity.task_id - ) - - def _handle_human_input_form_timeout_event( - self, event: QueueHumanInputFormTimeoutEvent, **kwargs - ) -> Generator[StreamResponse, None, None]: - """Handle human input form timeout events.""" - yield self._workflow_response_converter.human_input_form_timeout_to_stream_response( - event=event, task_id=self._application_generate_entity.task_id - ) - def _get_event_handlers(self) -> dict[type, Callable]: """Get mapping of event types to their handlers using fluent pattern.""" return { @@ -563,7 +506,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueueWorkflowStartedEvent: self._handle_workflow_started_event, QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event, - QueueWorkflowPausedEvent: self._handle_workflow_paused_event, # Node events QueueNodeRetryEvent: self._handle_node_retry_event, QueueNodeStartedEvent: self._handle_node_started_event, @@ -578,8 +520,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): 
QueueLoopCompletedEvent: self._handle_loop_completed_event, # Agent events QueueAgentLogEvent: self._handle_agent_log_event, - QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event, - QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event, } def _dispatch_event( @@ -662,9 +602,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): case QueueWorkflowFailedEvent(): yield from self._handle_workflow_failed_and_stop_events(event) break - case QueueWorkflowPausedEvent(): - yield from self._handle_workflow_paused_event(event) - break case QueueStopEvent(): yield from self._handle_workflow_failed_and_stop_events(event) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index c9d7464c17..13b7865f55 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,4 +1,3 @@ -import logging import time from collections.abc import Mapping, Sequence from typing import Any, cast @@ -8,8 +7,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, QueueAgentLogEvent, - QueueHumanInputFormFilledEvent, - QueueHumanInputFormTimeoutEvent, QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, @@ -25,27 +22,22 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, QueueWorkflowFailedEvent, QueueWorkflowPartialSuccessEvent, - QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams -from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.graph import Graph from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events import ( GraphEngineEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, - GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunAgentLogEvent, NodeRunExceptionEvent, NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, NodeRunIterationFailedEvent, NodeRunIterationNextEvent, NodeRunIterationStartedEvent, @@ -69,9 +61,6 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, from core.workflow.workflow_entry import WorkflowEntry from models.enums import UserFrom from models.workflow import Workflow -from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task - -logger = logging.getLogger(__name__) class WorkflowBasedAppRunner: @@ -338,7 +327,7 @@ class WorkflowBasedAppRunner: :param event: event """ if isinstance(event, GraphRunStartedEvent): - self._publish_event(QueueWorkflowStartedEvent(reason=event.reason)) + self._publish_event(QueueWorkflowStartedEvent()) elif isinstance(event, GraphRunSucceededEvent): self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) elif isinstance(event, GraphRunPartialSucceededEvent): @@ -349,38 +338,6 @@ class WorkflowBasedAppRunner: self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) elif isinstance(event, GraphRunAbortedEvent): self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0)) - elif isinstance(event, GraphRunPausedEvent): - runtime_state = workflow_entry.graph_engine.graph_runtime_state - paused_nodes = runtime_state.get_paused_nodes() - 
self._enqueue_human_input_notifications(event.reasons) - self._publish_event( - QueueWorkflowPausedEvent( - reasons=event.reasons, - outputs=event.outputs, - paused_nodes=paused_nodes, - ) - ) - elif isinstance(event, NodeRunHumanInputFormFilledEvent): - self._publish_event( - QueueHumanInputFormFilledEvent( - node_execution_id=event.id, - node_id=event.node_id, - node_type=event.node_type, - node_title=event.node_title, - rendered_content=event.rendered_content, - action_id=event.action_id, - action_text=event.action_text, - ) - ) - elif isinstance(event, NodeRunHumanInputFormTimeoutEvent): - self._publish_event( - QueueHumanInputFormTimeoutEvent( - node_id=event.node_id, - node_type=event.node_type, - node_title=event.node_title, - expiration_time=event.expiration_time, - ) - ) elif isinstance(event, NodeRunRetryEvent): node_run_result = event.node_run_result inputs = node_run_result.inputs @@ -587,19 +544,5 @@ class WorkflowBasedAppRunner: ) ) - def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None: - for reason in reasons: - if not isinstance(reason, HumanInputRequired): - continue - if not reason.form_id: - continue - try: - dispatch_human_input_email_task.apply_async( - kwargs={"form_id": reason.form_id, "node_title": reason.node_title}, - queue="mail", - ) - except Exception: # pragma: no cover - defensive logging - logger.exception("Failed to enqueue human input email task for form %s", reason.form_id) - def _publish_event(self, event: AppQueueEvent): self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 0e68e554c8..5bc453420d 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -132,7 +132,7 @@ class AppGenerateEntity(BaseModel): extras: dict[str, Any] = Field(default_factory=dict) # tracing instance - trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False) + trace_manager: Optional["TraceQueueManager"] = None class EasyUIBasedAppGenerateEntity(AppGenerateEntity): @@ -156,7 +156,6 @@ class ConversationAppGenerateEntity(AppGenerateEntity): """ conversation_id: str | None = None - is_new_conversation: bool = False parent_message_id: str | None = Field( default=None, description=( diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 5b2fa29b56..77d6bf03b4 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -8,8 +8,6 @@ from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import WorkflowNodeExecutionMetadataKey from core.workflow.nodes import NodeType @@ -48,9 +46,6 @@ class QueueEvent(StrEnum): PING = "ping" STOP = "stop" RETRY = "retry" - PAUSE = "pause" - HUMAN_INPUT_FORM_FILLED = "human_input_form_filled" - HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout" class AppQueueEvent(BaseModel): @@ -266,8 +261,6 @@ class QueueWorkflowStartedEvent(AppQueueEvent): """QueueWorkflowStartedEvent entity.""" event: QueueEvent = QueueEvent.WORKFLOW_STARTED - # Always present; mirrors 
GraphRunStartedEvent.reason for downstream consumers. - reason: WorkflowStartReason = WorkflowStartReason.INITIAL class QueueWorkflowSucceededEvent(AppQueueEvent): @@ -491,35 +484,6 @@ class QueueStopEvent(AppQueueEvent): return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.") -class QueueHumanInputFormFilledEvent(AppQueueEvent): - """ - QueueHumanInputFormFilledEvent entity - """ - - event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_FILLED - - node_execution_id: str - node_id: str - node_type: NodeType - node_title: str - rendered_content: str - action_id: str - action_text: str - - -class QueueHumanInputFormTimeoutEvent(AppQueueEvent): - """ - QueueHumanInputFormTimeoutEvent entity - """ - - event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_TIMEOUT - - node_id: str - node_type: NodeType - node_title: str - expiration_time: datetime - - class QueueMessage(BaseModel): """ QueueMessage abstract entity @@ -545,14 +509,3 @@ class WorkflowQueueMessage(QueueMessage): """ pass - - -class QueueWorkflowPausedEvent(AppQueueEvent): - """ - QueueWorkflowPausedEvent entity - """ - - event: QueueEvent = QueueEvent.PAUSE - reasons: Sequence[PauseReason] = Field(default_factory=list) - outputs: Mapping[str, object] = Field(default_factory=dict) - paused_nodes: Sequence[str] = Field(default_factory=list) diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 3f38904d2f..79a5e657b3 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -7,9 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.nodes.human_input.entities import FormInput, UserAction class AnnotationReplyAccount(BaseModel): @@ -71,7 +69,6 @@ class StreamEvent(StrEnum): AGENT_THOUGHT = "agent_thought" AGENT_MESSAGE = "agent_message" WORKFLOW_STARTED = "workflow_started" - WORKFLOW_PAUSED = "workflow_paused" WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" @@ -85,9 +82,6 @@ class StreamEvent(StrEnum): TEXT_CHUNK = "text_chunk" TEXT_REPLACE = "text_replace" AGENT_LOG = "agent_log" - HUMAN_INPUT_REQUIRED = "human_input_required" - HUMAN_INPUT_FORM_FILLED = "human_input_form_filled" - HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout" class StreamResponse(BaseModel): @@ -211,8 +205,6 @@ class WorkflowStartStreamResponse(StreamResponse): workflow_id: str inputs: Mapping[str, Any] created_at: int - # Always present; mirrors QueueWorkflowStartedEvent.reason for SSE clients. 
- reason: WorkflowStartReason = WorkflowStartReason.INITIAL event: StreamEvent = StreamEvent.WORKFLOW_STARTED workflow_run_id: str @@ -239,7 +231,7 @@ class WorkflowFinishStreamResponse(StreamResponse): total_steps: int created_by: Mapping[str, object] = Field(default_factory=dict) created_at: int - finished_at: int | None + finished_at: int exceptions_count: int | None = 0 files: Sequence[Mapping[str, Any]] | None = [] @@ -248,85 +240,6 @@ class WorkflowFinishStreamResponse(StreamResponse): data: Data -class WorkflowPauseStreamResponse(StreamResponse): - """ - WorkflowPauseStreamResponse entity - """ - - class Data(BaseModel): - """ - Data entity - """ - - workflow_run_id: str - paused_nodes: Sequence[str] = Field(default_factory=list) - outputs: Mapping[str, Any] = Field(default_factory=dict) - reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list) - status: str - created_at: int - elapsed_time: float - total_tokens: int - total_steps: int - - event: StreamEvent = StreamEvent.WORKFLOW_PAUSED - workflow_run_id: str - data: Data - - -class HumanInputRequiredResponse(StreamResponse): - class Data(BaseModel): - """ - Data entity - """ - - form_id: str - node_id: str - node_title: str - form_content: str - inputs: Sequence[FormInput] = Field(default_factory=list) - actions: Sequence[UserAction] = Field(default_factory=list) - display_in_ui: bool = False - form_token: str | None = None - resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - expiration_time: int = Field(..., description="Unix timestamp in seconds") - - event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED - workflow_run_id: str - data: Data - - -class HumanInputFormFilledResponse(StreamResponse): - class Data(BaseModel): - """ - Data entity - """ - - node_id: str - node_title: str - rendered_content: str - action_id: str - action_text: str - - event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED - workflow_run_id: str - data: Data - - -class HumanInputFormTimeoutResponse(StreamResponse): - class Data(BaseModel): - """ - Data entity - """ - - node_id: str - node_title: str - expiration_time: int - - event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_TIMEOUT - workflow_run_id: str - data: Data - - class NodeStartStreamResponse(StreamResponse): """ NodeStartStreamResponse entity @@ -813,7 +726,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): total_tokens: int total_steps: int created_at: int - finished_at: int | None + finished_at: int workflow_run_id: str data: Data diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 2ca1275a8a..565905be0d 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -1,4 +1,3 @@ -import contextlib import logging import time import uuid @@ -104,14 +103,6 @@ class RateLimit: ) -@contextlib.contextmanager -def rate_limit_context(rate_limit: RateLimit, request_id: str | None): - request_id = rate_limit.enter(request_id) - yield - if request_id is not None: - rate_limit.exit(request_id) - - class RateLimitGenerator: def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str): self.rate_limit = rate_limit diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 1c267091a4..bf76ae8178 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,4 +1,3 @@ -from 
dataclasses import dataclass from typing import Annotated, Literal, Self, TypeAlias from pydantic import BaseModel, Field @@ -53,14 +52,6 @@ class WorkflowResumptionContext(BaseModel): return self.generate_entity.entity -@dataclass(frozen=True) -class PauseStateLayerConfig: - """Configuration container for instantiating pause persistence layers.""" - - session_factory: Engine | sessionmaker[Session] - state_owner_user_id: str - - class PauseStatePersistenceLayer(GraphEngineLayer): def __init__( self, diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index d682083f34..2d4ee08daf 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -82,11 +82,10 @@ class MessageCycleManager: if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): return None - is_first_message = self._application_generate_entity.is_new_conversation + is_first_message = self._application_generate_entity.conversation_id is None extras = self._application_generate_entity.extras auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True) - thread: Thread | None = None if auto_generate_conversation_name and is_first_message: # start generate thread # time.sleep not block other logic @@ -102,10 +101,9 @@ class MessageCycleManager: thread.daemon = True thread.start() - if is_first_message: - self._application_generate_entity.is_new_conversation = False + return thread - return thread + return None def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py deleted file mode 100644 index 46006f4381..0000000000 --- a/api/core/entities/execution_extra_content.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from typing import Any, TypeAlias - -from pydantic import BaseModel, ConfigDict, Field - -from core.workflow.nodes.human_input.entities import FormInput, UserAction -from models.execution_extra_content import ExecutionContentType - - -class HumanInputFormDefinition(BaseModel): - model_config = ConfigDict(frozen=True) - - form_id: str - node_id: str - node_title: str - form_content: str - inputs: Sequence[FormInput] = Field(default_factory=list) - actions: Sequence[UserAction] = Field(default_factory=list) - display_in_ui: bool = False - form_token: str | None = None - resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - expiration_time: int - - -class HumanInputFormSubmissionData(BaseModel): - model_config = ConfigDict(frozen=True) - - node_id: str - node_title: str - rendered_content: str - action_id: str - action_text: str - - -class HumanInputContent(BaseModel): - model_config = ConfigDict(frozen=True) - - workflow_run_id: str - submitted: bool - form_definition: HumanInputFormDefinition | None = None - form_submission_data: HumanInputFormSubmissionData | None = None - type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT) - - -ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent - -__all__ = [ - "ExecutionExtraContentDomainModel", - "HumanInputContent", - "HumanInputFormDefinition", - "HumanInputFormSubmissionData", -] diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 
8a26b2e91b..e8d41b9387 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import ( ) from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models.engine import db from models.provider import ( LoadBalancingModelConfig, Provider, diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 549e428f88..84f5bf5512 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -15,7 +15,10 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token -from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum +from core.ops.entities.config_entity import ( + OPS_FILE_PATH, + TracingProviderEnum, +) from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -28,8 +31,8 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.utils import get_message_data +from extensions.ext_database import db from extensions.ext_storage import storage -from models.engine import db from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks @@ -466,8 +469,6 @@ class TraceTask: @classmethod def _get_workflow_run_repo(cls): - from repositories.factory import DifyAPIRepositoryFactory - if cls._workflow_run_repo is None: with cls._repo_lock: if cls._workflow_run_repo is None: diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index a5196d66c0..631e3b77b2 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse from sqlalchemy import select -from models.engine import db +from extensions.ext_database import db from models.model import Message diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 3c5df2b905..32e8ef385c 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,4 +1,3 @@ -import uuid from collections.abc import Generator, Mapping from typing import Union @@ -12,7 +11,6 @@ from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from extensions.ext_database import db from models import Account @@ -103,11 +101,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): if not workflow: raise ValueError("unexpected app type") - pause_config = PauseStateLayerConfig( - session_factory=db.engine, - state_owner_user_id=workflow.created_by, - ) - return AdvancedChatAppGenerator().generate( app_model=app, workflow=workflow, @@ -119,9 +112,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): "conversation_id": conversation_id, }, invoke_from=InvokeFrom.SERVICE_API, - 
workflow_run_id=str(uuid.uuid4()), streaming=stream, - pause_state_config=pause_config, ) elif app.mode == AppMode.AGENT_CHAT: return AgentChatAppGenerator().generate( @@ -168,11 +159,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): if not workflow: raise ValueError("unexpected app type") - pause_config = PauseStateLayerConfig( - session_factory=db.engine, - state_owner_user_id=workflow.created_by, - ) - return WorkflowAppGenerator().generate( app_model=app, workflow=workflow, @@ -181,7 +167,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): invoke_from=InvokeFrom.SERVICE_API, streaming=stream, call_depth=1, - pause_state_config=pause_config, ) @classmethod diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6f2826f634..d83823d7b9 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -1,18 +1,19 @@ -"""Repository implementations for data access.""" +""" +Repository implementations for data access. -from __future__ import annotations +This package contains concrete implementations of the repository interfaces +defined in the core.workflow.repository package. +""" -from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from .factory import DifyCoreRepositoryFactory, RepositoryImportError -from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ "CeleryWorkflowExecutionRepository", "CeleryWorkflowNodeExecutionRepository", "DifyCoreRepositoryFactory", "RepositoryImportError", - "SQLAlchemyWorkflowExecutionRepository", "SQLAlchemyWorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py deleted file mode 100644 index 0e04c56e0e..0000000000 --- a/api/core/repositories/human_input_repository.py +++ /dev/null @@ -1,553 +0,0 @@ -import dataclasses -import json -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any - -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, selectinload, sessionmaker - -from core.workflow.nodes.human_input.entities import ( - DeliveryChannelConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - FormDefinition, - HumanInputNodeData, - MemberRecipient, - WebAppDeliveryMethod, -) -from core.workflow.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) -from core.workflow.repositories.human_input_form_repository import ( - FormCreateParams, - FormNotFoundError, - HumanInputFormEntity, - HumanInputFormRecipientEntity, -) -from libs.datetime_utils import naive_utc_now -from libs.uuid_utils import uuidv7 -from models.account import Account, TenantAccountJoin -from models.human_input import ( - BackstageRecipientPayload, - ConsoleDeliveryPayload, - ConsoleRecipientPayload, 
- EmailExternalRecipientPayload, - EmailMemberRecipientPayload, - HumanInputDelivery, - HumanInputForm, - HumanInputFormRecipient, - RecipientType, - StandaloneWebAppRecipientPayload, -) - - -@dataclasses.dataclass(frozen=True) -class _DeliveryAndRecipients: - delivery: HumanInputDelivery - recipients: Sequence[HumanInputFormRecipient] - - -@dataclasses.dataclass(frozen=True) -class _WorkspaceMemberInfo: - user_id: str - email: str - - -class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity): - def __init__(self, recipient_model: HumanInputFormRecipient): - self._recipient_model = recipient_model - - @property - def id(self) -> str: - return self._recipient_model.id - - @property - def token(self) -> str: - if self._recipient_model.access_token is None: - raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}") - return self._recipient_model.access_token - - -class _HumanInputFormEntityImpl(HumanInputFormEntity): - def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]): - self._form_model = form_model - self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models] - self._web_app_recipient = next( - ( - recipient - for recipient in recipient_models - if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP - ), - None, - ) - self._console_recipient = next( - (recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.CONSOLE), - None, - ) - self._submitted_data: Mapping[str, Any] | None = ( - json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None - ) - - @property - def id(self) -> str: - return self._form_model.id - - @property - def web_app_token(self): - if self._console_recipient is not None: - return self._console_recipient.access_token - if self._web_app_recipient is None: - return None - return self._web_app_recipient.access_token - - @property - def recipients(self) -> list[HumanInputFormRecipientEntity]: - return list(self._recipients) - - @property - def rendered_content(self) -> str: - return self._form_model.rendered_content - - @property - def selected_action_id(self) -> str | None: - return self._form_model.selected_action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self._submitted_data - - @property - def submitted(self) -> bool: - return self._form_model.submitted_at is not None - - @property - def status(self) -> HumanInputFormStatus: - return self._form_model.status - - @property - def expiration_time(self) -> datetime: - return self._form_model.expiration_time - - -@dataclasses.dataclass(frozen=True) -class HumanInputFormRecord: - form_id: str - workflow_run_id: str | None - node_id: str - tenant_id: str - app_id: str - form_kind: HumanInputFormKind - definition: FormDefinition - rendered_content: str - created_at: datetime - expiration_time: datetime - status: HumanInputFormStatus - selected_action_id: str | None - submitted_data: Mapping[str, Any] | None - submitted_at: datetime | None - submission_user_id: str | None - submission_end_user_id: str | None - completed_by_recipient_id: str | None - recipient_id: str | None - recipient_type: RecipientType | None - access_token: str | None - - @property - def submitted(self) -> bool: - return self.submitted_at is not None - - @classmethod - def from_models( - cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None - ) -> "HumanInputFormRecord": - 
definition_payload = json.loads(form_model.form_definition) - if "expiration_time" not in definition_payload: - definition_payload["expiration_time"] = form_model.expiration_time - return cls( - form_id=form_model.id, - workflow_run_id=form_model.workflow_run_id, - node_id=form_model.node_id, - tenant_id=form_model.tenant_id, - app_id=form_model.app_id, - form_kind=form_model.form_kind, - definition=FormDefinition.model_validate(definition_payload), - rendered_content=form_model.rendered_content, - created_at=form_model.created_at, - expiration_time=form_model.expiration_time, - status=form_model.status, - selected_action_id=form_model.selected_action_id, - submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None, - submitted_at=form_model.submitted_at, - submission_user_id=form_model.submission_user_id, - submission_end_user_id=form_model.submission_end_user_id, - completed_by_recipient_id=form_model.completed_by_recipient_id, - recipient_id=recipient_model.id if recipient_model else None, - recipient_type=recipient_model.recipient_type if recipient_model else None, - access_token=recipient_model.access_token if recipient_model else None, - ) - - -class _InvalidTimeoutStatusError(ValueError): - pass - - -class HumanInputFormRepositoryImpl: - def __init__( - self, - session_factory: sessionmaker | Engine, - tenant_id: str, - ): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - self._tenant_id = tenant_id - - def _delivery_method_to_model( - self, - session: Session, - form_id: str, - delivery_method: DeliveryChannelConfig, - ) -> _DeliveryAndRecipients: - delivery_id = str(uuidv7()) - delivery_model = HumanInputDelivery( - id=delivery_id, - form_id=form_id, - delivery_method_type=delivery_method.type, - delivery_config_id=delivery_method.id, - channel_payload=delivery_method.model_dump_json(), - ) - recipients: list[HumanInputFormRecipient] = [] - if isinstance(delivery_method, WebAppDeliveryMethod): - recipient_model = HumanInputFormRecipient( - form_id=form_id, - delivery_id=delivery_id, - recipient_type=RecipientType.STANDALONE_WEB_APP, - recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(), - ) - recipients.append(recipient_model) - elif isinstance(delivery_method, EmailDeliveryMethod): - email_recipients_config = delivery_method.config.recipients - recipients.extend( - self._build_email_recipients( - session=session, - form_id=form_id, - delivery_id=delivery_id, - recipients_config=email_recipients_config, - ) - ) - - return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients) - - def _build_email_recipients( - self, - session: Session, - form_id: str, - delivery_id: str, - recipients_config: EmailRecipients, - ) -> list[HumanInputFormRecipient]: - member_user_ids = [ - recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient) - ] - external_emails = [ - recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient) - ] - if recipients_config.whole_workspace: - members = self._query_all_workspace_members(session=session) - else: - members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids) - - return self._create_email_recipients_from_resolved( - form_id=form_id, - delivery_id=delivery_id, - members=members, - external_emails=external_emails, - ) - - @staticmethod - def 
_create_email_recipients_from_resolved( - *, - form_id: str, - delivery_id: str, - members: Sequence[_WorkspaceMemberInfo], - external_emails: Sequence[str], - ) -> list[HumanInputFormRecipient]: - recipient_models: list[HumanInputFormRecipient] = [] - seen_emails: set[str] = set() - - for member in members: - if not member.email: - continue - if member.email in seen_emails: - continue - seen_emails.add(member.email) - payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email) - recipient_models.append( - HumanInputFormRecipient.new( - form_id=form_id, - delivery_id=delivery_id, - payload=payload, - ) - ) - - for email in external_emails: - if not email: - continue - if email in seen_emails: - continue - seen_emails.add(email) - recipient_models.append( - HumanInputFormRecipient.new( - form_id=form_id, - delivery_id=delivery_id, - payload=EmailExternalRecipientPayload(email=email), - ) - ) - - return recipient_models - - def _query_all_workspace_members( - self, - session: Session, - ) -> list[_WorkspaceMemberInfo]: - stmt = ( - select(Account.id, Account.email) - .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) - .where(TenantAccountJoin.tenant_id == self._tenant_id) - ) - rows = session.execute(stmt).all() - return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] - - def _query_workspace_members_by_ids( - self, - session: Session, - restrict_to_user_ids: Sequence[str], - ) -> list[_WorkspaceMemberInfo]: - unique_ids = {user_id for user_id in restrict_to_user_ids if user_id} - if not unique_ids: - return [] - - stmt = ( - select(Account.id, Account.email) - .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) - .where(TenantAccountJoin.tenant_id == self._tenant_id) - ) - stmt = stmt.where(Account.id.in_(unique_ids)) - - rows = session.execute(stmt).all() - return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - form_config: HumanInputNodeData = params.form_config - - with self._session_factory(expire_on_commit=False) as session, session.begin(): - # Generate unique form ID - form_id = str(uuidv7()) - start_time = naive_utc_now() - node_expiration = form_config.expiration_time(start_time) - form_definition = FormDefinition( - form_content=form_config.form_content, - inputs=form_config.inputs, - user_actions=form_config.user_actions, - rendered_content=params.rendered_content, - expiration_time=node_expiration, - default_values=dict(params.resolved_default_values), - display_in_ui=params.display_in_ui, - node_title=form_config.title, - ) - form_model = HumanInputForm( - id=form_id, - tenant_id=self._tenant_id, - app_id=params.app_id, - workflow_run_id=params.workflow_execution_id, - form_kind=params.form_kind, - node_id=params.node_id, - form_definition=form_definition.model_dump_json(), - rendered_content=params.rendered_content, - expiration_time=node_expiration, - created_at=start_time, - ) - session.add(form_model) - recipient_models: list[HumanInputFormRecipient] = [] - for delivery in params.delivery_methods: - delivery_and_recipients = self._delivery_method_to_model( - session=session, - form_id=form_id, - delivery_method=delivery, - ) - session.add(delivery_and_recipients.delivery) - session.add_all(delivery_and_recipients.recipients) - recipient_models.extend(delivery_and_recipients.recipients) - if params.console_recipient_required and not any( - recipient.recipient_type 
== RecipientType.CONSOLE for recipient in recipient_models - ): - console_delivery_id = str(uuidv7()) - console_delivery = HumanInputDelivery( - id=console_delivery_id, - form_id=form_id, - delivery_method_type=DeliveryMethodType.WEBAPP, - delivery_config_id=None, - channel_payload=ConsoleDeliveryPayload().model_dump_json(), - ) - console_recipient = HumanInputFormRecipient( - form_id=form_id, - delivery_id=console_delivery_id, - recipient_type=RecipientType.CONSOLE, - recipient_payload=ConsoleRecipientPayload( - account_id=params.console_creator_account_id, - ).model_dump_json(), - ) - session.add(console_delivery) - session.add(console_recipient) - recipient_models.append(console_recipient) - if params.backstage_recipient_required and not any( - recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models - ): - backstage_delivery_id = str(uuidv7()) - backstage_delivery = HumanInputDelivery( - id=backstage_delivery_id, - form_id=form_id, - delivery_method_type=DeliveryMethodType.WEBAPP, - delivery_config_id=None, - channel_payload=ConsoleDeliveryPayload().model_dump_json(), - ) - backstage_recipient = HumanInputFormRecipient( - form_id=form_id, - delivery_id=backstage_delivery_id, - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload( - account_id=params.console_creator_account_id, - ).model_dump_json(), - ) - session.add(backstage_delivery) - session.add(backstage_recipient) - recipient_models.append(backstage_recipient) - session.flush() - - return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - form_query = select(HumanInputForm).where( - HumanInputForm.workflow_run_id == workflow_execution_id, - HumanInputForm.node_id == node_id, - HumanInputForm.tenant_id == self._tenant_id, - ) - with self._session_factory(expire_on_commit=False) as session: - form_model: HumanInputForm | None = session.scalars(form_query).first() - if form_model is None: - return None - - recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id) - recipient_models = session.scalars(recipient_query).all() - return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - - -class HumanInputFormSubmissionRepository: - """Repository for fetching and submitting human input forms.""" - - def __init__(self, session_factory: sessionmaker | Engine): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - - def get_by_token(self, form_token: str) -> HumanInputFormRecord | None: - query = ( - select(HumanInputFormRecipient) - .options(selectinload(HumanInputFormRecipient.form)) - .where(HumanInputFormRecipient.access_token == form_token) - ) - with self._session_factory(expire_on_commit=False) as session: - recipient_model = session.scalars(query).first() - if recipient_model is None or recipient_model.form is None: - return None - return HumanInputFormRecord.from_models(recipient_model.form, recipient_model) - - def get_by_form_id_and_recipient_type( - self, - form_id: str, - recipient_type: RecipientType, - ) -> HumanInputFormRecord | None: - query = ( - select(HumanInputFormRecipient) - .options(selectinload(HumanInputFormRecipient.form)) - .where( - HumanInputFormRecipient.form_id == form_id, - HumanInputFormRecipient.recipient_type == recipient_type, - ) - ) - with 
self._session_factory(expire_on_commit=False) as session: - recipient_model = session.scalars(query).first() - if recipient_model is None or recipient_model.form is None: - return None - return HumanInputFormRecord.from_models(recipient_model.form, recipient_model) - - def mark_submitted( - self, - *, - form_id: str, - recipient_id: str | None, - selected_action_id: str, - form_data: Mapping[str, Any], - submission_user_id: str | None, - submission_end_user_id: str | None, - ) -> HumanInputFormRecord: - with self._session_factory(expire_on_commit=False) as session, session.begin(): - form_model = session.get(HumanInputForm, form_id) - if form_model is None: - raise FormNotFoundError(f"form not found, id={form_id}") - - recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None - - form_model.selected_action_id = selected_action_id - form_model.submitted_data = json.dumps(form_data) - form_model.submitted_at = naive_utc_now() - form_model.status = HumanInputFormStatus.SUBMITTED - form_model.submission_user_id = submission_user_id - form_model.submission_end_user_id = submission_end_user_id - form_model.completed_by_recipient_id = recipient_id - - session.add(form_model) - session.flush() - session.refresh(form_model) - if recipient_model is not None: - session.refresh(recipient_model) - - return HumanInputFormRecord.from_models(form_model, recipient_model) - - def mark_timeout( - self, - *, - form_id: str, - timeout_status: HumanInputFormStatus, - reason: str | None = None, - ) -> HumanInputFormRecord: - with self._session_factory(expire_on_commit=False) as session, session.begin(): - form_model = session.get(HumanInputForm, form_id) - if form_model is None: - raise FormNotFoundError(f"form not found, id={form_id}") - - if timeout_status not in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}: - raise _InvalidTimeoutStatusError(f"invalid timeout status: {timeout_status}") - - # already handled or submitted - if form_model.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}: - return HumanInputFormRecord.from_models(form_model, None) - - if form_model.submitted_at is not None or form_model.status == HumanInputFormStatus.SUBMITTED: - raise FormNotFoundError(f"form already submitted, id={form_id}") - - form_model.status = timeout_status - form_model.selected_action_id = None - form_model.submitted_data = None - form_model.submission_user_id = None - form_model.submission_end_user_id = None - form_model.completed_by_recipient_id = None - # Reason is recorded in status/error downstream; not stored on form. 
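- # Persist the transition; flush and refresh so the returned record reflects the final timeout state.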
- session.add(form_model) - session.flush() - session.refresh(form_model) - - return HumanInputFormRecord.from_models(form_model, None) diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 324dd059d1..4436773d25 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -488,7 +488,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, WorkflowNodeExecutionModel.tenant_id == self._tenant_id, WorkflowNodeExecutionModel.triggered_from == triggered_from, - WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED, ) if self._app_id: diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index 4c3efd6ff9..e4afe24426 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -1,5 +1,4 @@ from core.tools.entities.tool_entities import ToolInvokeMeta -from libs.exception import BaseHTTPException class ToolProviderNotFoundError(ValueError): @@ -38,12 +37,6 @@ class ToolCredentialPolicyViolationError(ValueError): pass -class WorkflowToolHumanInputNotSupportedError(BaseHTTPException): - error_code = "workflow_tool_human_input_not_supported" - description = "Workflow with Human Input nodes cannot be published as a workflow tool." - code = 400 - - class ToolEngineInvokeError(Exception): meta: ToolInvokeMeta diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 8588ccc718..188da0c32d 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -3,8 +3,6 @@ from typing import Any from core.app.app_config.entities import VariableEntity from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from core.workflow.enums import NodeType from core.workflow.nodes.base.entities import OutputVariableEntity @@ -52,13 +50,6 @@ class WorkflowToolConfigurationUtils: return [outputs_by_variable[variable] for variable in variable_order] - @classmethod - def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None: - nodes = graph.get("nodes", []) - for node in nodes: - if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT: - raise WorkflowToolHumanInputNotSupportedError() - @classmethod def check_is_synced( cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 01fa5de31e..9c1ceff145 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -98,10 +98,6 @@ class WorkflowTool(Tool): invoke_from=self.runtime.invoke_from, streaming=False, call_depth=self.workflow_call_depth + 1, - # NOTE(QuantumGhost): We explicitly set `pause_state_config` to `None` - # because workflow pausing mechanisms (such as HumanInput) are not - # supported within WorkflowTool execution context. 
- pause_state_config=None, ) assert isinstance(result, dict) data = result.get("data", {}) diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py index e73c38c1d3..be70e467a0 100644 --- a/api/core/workflow/entities/__init__.py +++ b/api/core/workflow/entities/__init__.py @@ -2,12 +2,10 @@ from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams from .workflow_execution import WorkflowExecution from .workflow_node_execution import WorkflowNodeExecution -from .workflow_start_reason import WorkflowStartReason __all__ = [ "AgentNodeStrategyInit", "GraphInitParams", "WorkflowExecution", "WorkflowNodeExecution", - "WorkflowStartReason", ] diff --git a/api/core/workflow/entities/graph_init_params.py b/api/core/workflow/entities/graph_init_params.py index ff224a28d1..7bf25b9f43 100644 --- a/api/core/workflow/entities/graph_init_params.py +++ b/api/core/workflow/entities/graph_init_params.py @@ -5,16 +5,6 @@ from pydantic import BaseModel, Field class GraphInitParams(BaseModel): - """GraphInitParams encapsulates the configurations and contextual information - that remain constant throughout a single execution of the graph engine. - - A single execution lasts until the run reaches its conclusion: if a workflow is suspended - and later resumed, it is still regarded as a single execution, not two. - - For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`. - """ - # init params tenant_id: str = Field(..., description="tenant / workspace id") app_id: str = Field(..., description="app id") diff --git a/api/core/workflow/entities/pause_reason.py b/api/core/workflow/entities/pause_reason.py index 147f56e8be..c6655b7eab 100644 --- a/api/core/workflow/entities/pause_reason.py +++ b/api/core/workflow/entities/pause_reason.py @@ -1,11 +1,8 @@ -from collections.abc import Mapping from enum import StrEnum, auto -from typing import Annotated, Any, Literal, TypeAlias +from typing import Annotated, Literal, TypeAlias from pydantic import BaseModel, Field -from core.workflow.nodes.human_input.entities import FormInput, UserAction - class PauseReasonType(StrEnum): HUMAN_INPUT_REQUIRED = auto() @@ -14,31 +11,10 @@ class PauseReasonType(StrEnum): class HumanInputRequired(BaseModel): TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED + form_id: str - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - actions: list[UserAction] = Field(default_factory=list) - display_in_ui: bool = False + # The identifier of the human input node causing the pause. node_id: str - node_title: str - - # `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from - # output variable names to their resolved values. - # - # For example, suppose the form contains an input with output variable name `name` and placeholder type `VARIABLE`, whose - # selector is ["start", "name"]. When the HumanInputNode is executed, the corresponding value of the variable - # `start.name` in the variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`, and - # `resolved_default_values` is `{"name": "John"}`. - # - # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`. - resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - - # The `form_token` is the token used to submit the form via UI surfaces.
It corresponds to - `HumanInputFormRecipient.access_token`. - # - # This field is `None` if webapp delivery is not configured and the run is not - # in orchestrating mode. - form_token: str | None = None class SchedulingPause(BaseModel): diff --git a/api/core/workflow/entities/workflow_start_reason.py b/api/core/workflow/entities/workflow_start_reason.py deleted file mode 100644 index df0f75383b..0000000000 --- a/api/core/workflow/entities/workflow_start_reason.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import StrEnum - - -class WorkflowStartReason(StrEnum): - """Reason for workflow start events across graph/queue/SSE layers.""" - - INITIAL = "initial" # First start of a workflow run. - RESUMPTION = "resumption" # Start triggered after resuming a paused run. diff --git a/api/core/workflow/graph_engine/_engine_utils.py b/api/core/workflow/graph_engine/_engine_utils.py deleted file mode 100644 index 28898268fe..0000000000 --- a/api/core/workflow/graph_engine/_engine_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -import time - - -def get_timestamp() -> float: - """Retrieve a timestamp as a floating-point number representing the number of seconds - since the Unix epoch. - - This function is primarily used to measure the execution time of the workflow engine. - Since workflow execution may be paused and resumed on a different machine, - `time.perf_counter` cannot be used, as it is inconsistent across machines. - - To address this, the function uses the wall clock as the time source. - However, it assumes that the clocks of all servers are properly synchronized. - """ - return round(time.time()) diff --git a/api/core/workflow/graph_engine/config.py b/api/core/workflow/graph_engine/config.py index d56a69cee0..10dbbd7535 100644 --- a/api/core/workflow/graph_engine/config.py +++ b/api/core/workflow/graph_engine/config.py @@ -2,14 +2,12 @@ GraphEngine configuration models.
""" -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel class GraphEngineConfig(BaseModel): """Configuration for GraphEngine worker pool scaling.""" - model_config = ConfigDict(frozen=True) - min_workers: int = 1 max_workers: int = 5 scale_up_threshold: int = 3 diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 98a0702e1c..5b0f56e59d 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -192,13 +192,9 @@ class EventHandler: self._event_collector.collect(edge_event) # Enqueue ready nodes - if self._graph_execution.is_paused: - for node_id in ready_nodes: - self._graph_runtime_state.register_deferred_node(node_id) - else: - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) + for node_id in ready_nodes: + self._state_manager.enqueue_node(node_id) + self._state_manager.start_execution(node_id) # Update execution tracking self._state_manager.finish_execution(event.node_id) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ac9e00e29e..0b359a2392 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -14,7 +14,6 @@ from collections.abc import Generator from typing import TYPE_CHECKING, cast, final from core.workflow.context import capture_current_context -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import NodeExecutionType from core.workflow.graph import Graph from core.workflow.graph_events import ( @@ -57,9 +56,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -_DEFAULT_CONFIG = GraphEngineConfig() - - @final class GraphEngine: """ @@ -75,7 +71,7 @@ class GraphEngine: graph: Graph, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel, - config: GraphEngineConfig = _DEFAULT_CONFIG, + config: GraphEngineConfig, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" # stop event @@ -239,9 +235,7 @@ class GraphEngine: self._graph_execution.paused = False self._graph_execution.pause_reasons = [] - start_event = GraphRunStartedEvent( - reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL, - ) + start_event = GraphRunStartedEvent() self._event_manager.notify_layers(start_event) yield start_event @@ -310,17 +304,15 @@ class GraphEngine: for layer in self._layers: try: layer.on_graph_start() - except Exception: - logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__) + except Exception as e: + logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e) def _start_execution(self, *, resume: bool = False) -> None: """Start execution subsystems.""" self._stop_event.clear() paused_nodes: list[str] = [] - deferred_nodes: list[str] = [] if resume: paused_nodes = self._graph_runtime_state.consume_paused_nodes() - deferred_nodes = self._graph_runtime_state.consume_deferred_nodes() # Start worker pool (it calculates initial workers internally) self._worker_pool.start() @@ -336,11 +328,7 @@ class GraphEngine: self._state_manager.enqueue_node(root_node.id) self._state_manager.start_execution(root_node.id) else: - seen_nodes: set[str] = set() - for node_id in paused_nodes + deferred_nodes: - if node_id in 
seen_nodes: - continue - seen_nodes.add(node_id) + for node_id in paused_nodes: self._state_manager.enqueue_node(node_id) self._state_manager.start_execution(node_id) @@ -358,8 +346,8 @@ class GraphEngine: for layer in self._layers: try: layer.on_graph_end(self._graph_execution.error) - except Exception: - logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__) + except Exception as e: + logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e) # Public property accessors for attributes that need external access @property diff --git a/api/core/workflow/graph_engine/graph_state_manager.py b/api/core/workflow/graph_engine/graph_state_manager.py index d9773645c3..22a3a826fc 100644 --- a/api/core/workflow/graph_engine/graph_state_manager.py +++ b/api/core/workflow/graph_engine/graph_state_manager.py @@ -224,8 +224,6 @@ class GraphStateManager: Returns: Number of executing nodes """ - # This count is a best-effort snapshot and can change concurrently. - # Only use it for pause-drain checks where scheduling is already frozen. with self._lock: return len(self._executing_nodes) diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index d40d15c545..27439a2412 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -83,12 +83,12 @@ class Dispatcher: """Main dispatcher loop.""" try: self._process_commands() - paused = False while not self._stop_event.is_set(): - if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete: - break - if self._execution_coordinator.paused: - paused = True + if ( + self._execution_coordinator.aborted + or self._execution_coordinator.paused + or self._execution_coordinator.execution_complete + ): break self._execution_coordinator.check_scaling() @@ -101,10 +101,13 @@ class Dispatcher: time.sleep(0.1) self._process_commands() - if paused: - self._drain_events_until_idle() - else: - self._drain_event_queue() + while True: + try: + event = self._event_queue.get(block=False) + self._event_handler.dispatch(event) + self._event_queue.task_done() + except queue.Empty: + break except Exception as e: logger.exception("Dispatcher error") @@ -119,24 +122,3 @@ class Dispatcher: def _process_commands(self, event: GraphNodeEventBase | None = None): if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS): self._execution_coordinator.process_commands() - - def _drain_event_queue(self) -> None: - while True: - try: - event = self._event_queue.get(block=False) - self._event_handler.dispatch(event) - self._event_queue.task_done() - except queue.Empty: - break - - def _drain_events_until_idle(self) -> None: - while not self._stop_event.is_set(): - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - if not self._execution_coordinator.has_executing_nodes(): - break - self._drain_event_queue() diff --git a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py index 0f8550eb12..e8e8f9f16c 100644 --- a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py +++ b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py @@ -94,11 +94,3 @@ class ExecutionCoordinator: self._worker_pool.stop() 
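# Clear executing-node bookkeeping once the workers have stopped.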
self._state_manager.clear_executing() - - def has_executing_nodes(self) -> bool: - """Return True if any nodes are currently marked as executing.""" - # This check is only safe once execution has already paused. - # Before pause, executing state can change concurrently, which makes the result unreliable. - if not self._graph_execution.is_paused: - raise AssertionError("has_executing_nodes should only be called after execution is paused") - return self._state_manager.get_executing_count() > 0 diff --git a/api/core/workflow/graph_events/__init__.py b/api/core/workflow/graph_events/__init__.py index 56ea642092..2b6ee4ec1c 100644 --- a/api/core/workflow/graph_events/__init__.py +++ b/api/core/workflow/graph_events/__init__.py @@ -38,8 +38,6 @@ from .loop import ( from .node import ( NodeRunExceptionEvent, NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, NodeRunPauseRequestedEvent, NodeRunRetrieverResourceEvent, NodeRunRetryEvent, @@ -62,8 +60,6 @@ __all__ = [ "NodeRunAgentLogEvent", "NodeRunExceptionEvent", "NodeRunFailedEvent", - "NodeRunHumanInputFormFilledEvent", - "NodeRunHumanInputFormTimeoutEvent", "NodeRunIterationFailedEvent", "NodeRunIterationNextEvent", "NodeRunIterationStartedEvent", diff --git a/api/core/workflow/graph_events/graph.py b/api/core/workflow/graph_events/graph.py index f46526bcab..5d10a76c15 100644 --- a/api/core/workflow/graph_events/graph.py +++ b/api/core/workflow/graph_events/graph.py @@ -1,16 +1,11 @@ from pydantic import Field from core.workflow.entities.pause_reason import PauseReason -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.graph_events import BaseGraphEvent class GraphRunStartedEvent(BaseGraphEvent): - # Reason is emitted for workflow start events and is always set. 
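- # Defaults to INITIAL; GraphEngine sets RESUMPTION when a paused run is resumed.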
- reason: WorkflowStartReason = Field( - default=WorkflowStartReason.INITIAL, - description="reason for workflow start", - ) + pass class GraphRunSucceededEvent(BaseGraphEvent): diff --git a/api/core/workflow/graph_events/human_input.py b/api/core/workflow/graph_events/human_input.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py index 975d72ad1f..4d0108e77b 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/core/workflow/graph_events/node.py @@ -54,22 +54,6 @@ class NodeRunRetryEvent(NodeRunStartedEvent): retry_index: int = Field(..., description="which retry attempt is about to be performed") -class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase): - """Emitted when a HumanInput form is submitted and before the node finishes.""" - - node_title: str = Field(..., description="HumanInput node title") - rendered_content: str = Field(..., description="Markdown content rendered with user inputs.") - action_id: str = Field(..., description="User action identifier chosen in the form.") - action_text: str = Field(..., description="Display text of the chosen action button.") - - -class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase): - """Emitted when a HumanInput form times out.""" - - node_title: str = Field(..., description="HumanInput node title") - expiration_time: datetime = Field(..., description="Form expiration time") - - class NodeRunPauseRequestedEvent(GraphNodeEventBase): reason: PauseReason = Field(..., description="pause reason") diff --git a/api/core/workflow/node_events/__init__.py b/api/core/workflow/node_events/__init__.py index a9bef8f9a2..f14a594c85 100644 --- a/api/core/workflow/node_events/__init__.py +++ b/api/core/workflow/node_events/__init__.py @@ -13,8 +13,6 @@ from .loop import ( LoopSucceededEvent, ) from .node import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, ModelInvokeCompletedEvent, PauseRequestedEvent, RunRetrieverResourceEvent, @@ -25,8 +23,6 @@ from .node import ( __all__ = [ "AgentLogEvent", - "HumanInputFormFilledEvent", - "HumanInputFormTimeoutEvent", "IterationFailedEvent", "IterationNextEvent", "IterationStartedEvent", diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index 9c76b7d7c2..e4fa52f444 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -47,19 +47,3 @@ class StreamCompletedEvent(NodeEventBase): class PauseRequestedEvent(NodeEventBase): reason: PauseReason = Field(..., description="pause reason") - - -class HumanInputFormFilledEvent(NodeEventBase): - """Event emitted when a human input form is submitted.""" - - node_title: str - rendered_content: str - action_id: str - action_text: str - - -class HumanInputFormTimeoutEvent(NodeEventBase): - """Event emitted when a human input form times out.""" - - node_title: str - expiration_time: datetime diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 2b773b537c..63e0260341 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -18,8 +18,6 @@ from core.workflow.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, NodeRunIterationFailedEvent, NodeRunIterationNextEvent, NodeRunIterationStartedEvent, @@ -36,8 +34,6 @@ from core.workflow.graph_events import ( ) from core.workflow.node_events import ( 
AgentLogEvent, - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, IterationFailedEvent, IterationNextEvent, IterationStartedEvent, @@ -65,15 +61,6 @@ logger = logging.getLogger(__name__) class Node(Generic[NodeDataT]): - """BaseNode serves as the foundational class for all node implementations. - - Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output` - attribute to track files generated by the LLM). However, these states are not persisted - when the workflow is suspended or resumed. If a node needs its state to be preserved - across workflow suspension and resumption, it should include the relevant state data - in its output. - """ - node_type: ClassVar[NodeType] execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData @@ -264,33 +251,10 @@ class Node(Generic[NodeDataT]): return self._node_execution_id def ensure_execution_id(self) -> str: - if self._node_execution_id: - return self._node_execution_id - - resumed_execution_id = self._restore_execution_id_from_runtime_state() - if resumed_execution_id: - self._node_execution_id = resumed_execution_id - return self._node_execution_id - - self._node_execution_id = str(uuid4()) + if not self._node_execution_id: + self._node_execution_id = str(uuid4()) return self._node_execution_id - def _restore_execution_id_from_runtime_state(self) -> str | None: - graph_execution = self.graph_runtime_state.graph_execution - try: - node_executions = graph_execution.node_executions - except AttributeError: - return None - if not isinstance(node_executions, dict): - return None - node_execution = node_executions.get(self._node_id) - if node_execution is None: - return None - execution_id = node_execution.execution_id - if not execution_id: - return None - return str(execution_id) - def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: return cast(NodeDataT, self._node_data_type.model_validate(data)) @@ -656,28 +620,6 @@ class Node(Generic[NodeDataT]): metadata=event.metadata, ) - @_dispatch.register - def _(self, event: HumanInputFormFilledEvent): - return NodeRunHumanInputFormFilledEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - rendered_content=event.rendered_content, - action_id=event.action_id, - action_text=event.action_text, - ) - - @_dispatch.register - def _(self, event: HumanInputFormTimeoutEvent): - return NodeRunHumanInputFormTimeoutEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - expiration_time=event.expiration_time, - ) - @_dispatch.register def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: return NodeRunLoopStartedEvent( diff --git a/api/core/workflow/nodes/human_input/__init__.py b/api/core/workflow/nodes/human_input/__init__.py index 1789604577..379440557c 100644 --- a/api/core/workflow/nodes/human_input/__init__.py +++ b/api/core/workflow/nodes/human_input/__init__.py @@ -1,3 +1,3 @@ -""" -Human Input node implementation. -""" +from .human_input_node import HumanInputNode + +__all__ = ["HumanInputNode"] diff --git a/api/core/workflow/nodes/human_input/entities.py b/api/core/workflow/nodes/human_input/entities.py index 72d4fc675b..02913d93c3 100644 --- a/api/core/workflow/nodes/human_input/entities.py +++ b/api/core/workflow/nodes/human_input/entities.py @@ -1,350 +1,10 @@ -""" -Human Input node entities. 
-""" +from pydantic import Field -import re -import uuid -from collections.abc import Mapping, Sequence -from datetime import datetime, timedelta -from typing import Annotated, Any, ClassVar, Literal, Self - -from pydantic import BaseModel, Field, field_validator, model_validator - -from core.variables.consts import SELECTORS_LENGTH from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.runtime import VariablePool - -from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit - -_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") - - -class _WebAppDeliveryConfig(BaseModel): - """Configuration for webapp delivery method.""" - - pass # Empty for webapp delivery - - -class MemberRecipient(BaseModel): - """Member recipient for email delivery.""" - - type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER - user_id: str - - -class ExternalRecipient(BaseModel): - """External recipient for email delivery.""" - - type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL - email: str - - -EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")] - - -class EmailRecipients(BaseModel): - """Email recipients configuration.""" - - # When true, recipients are the union of all workspace members and external items. - # Member items are ignored because they are already covered by the workspace scope. - # De-duplication is applied by email, with member recipients taking precedence. - whole_workspace: bool = False - items: list[EmailRecipient] = Field(default_factory=list) - - -class EmailDeliveryConfig(BaseModel): - """Configuration for email delivery method.""" - - URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" - - recipients: EmailRecipients - - # the subject of email - subject: str - - # Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which - # represent the url to submit the form. - # - # It may also reference the output variable of the previous node with the syntax - # `{{#.#}}`. 
- body: str - debug_mode: bool = False - - def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig": - if not user_id: - debug_recipients = EmailRecipients(whole_workspace=False, items=[]) - return self.model_copy(update={"recipients": debug_recipients}) - debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)]) - return self.model_copy(update={"recipients": debug_recipients}) - - @classmethod - def replace_url_placeholder(cls, body: str, url: str | None) -> str: - """Replace the url placeholder with provided value.""" - return body.replace(cls.URL_PLACEHOLDER, url or "") - - @classmethod - def render_body_template( - cls, - *, - body: str, - url: str | None, - variable_pool: VariablePool | None = None, - ) -> str: - """Render email body by replacing placeholders with runtime values.""" - templated_body = cls.replace_url_placeholder(body, url) - if variable_pool is None: - return templated_body - return variable_pool.convert_template(templated_body).text - - -class _DeliveryMethodBase(BaseModel): - """Base delivery method configuration.""" - - enabled: bool = True - id: uuid.UUID = Field(default_factory=uuid.uuid4) - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - return () - - -class WebAppDeliveryMethod(_DeliveryMethodBase): - """Webapp delivery method configuration.""" - - type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP - # The config field is not used currently. - config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig) - - -class EmailDeliveryMethod(_DeliveryMethodBase): - """Email delivery method configuration.""" - - type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL - config: EmailDeliveryConfig - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - variable_template_parser = VariableTemplateParser(template=self.config.body) - selectors: list[Sequence[str]] = [] - for variable_selector in variable_template_parser.extract_variable_selectors(): - value_selector = list(variable_selector.value_selector) - if len(value_selector) < SELECTORS_LENGTH: - continue - selectors.append(value_selector[:SELECTORS_LENGTH]) - return selectors - - -DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] - - -def apply_debug_email_recipient( - method: DeliveryChannelConfig, - *, - enabled: bool, - user_id: str, -) -> DeliveryChannelConfig: - if not enabled: - return method - if not isinstance(method, EmailDeliveryMethod): - return method - if not method.config.debug_mode: - return method - debug_config = method.config.with_debug_recipient(user_id or "") - return method.model_copy(update={"config": debug_config}) - - -class FormInputDefault(BaseModel): - """Default configuration for form inputs.""" - - # NOTE: Ideally, a discriminated union would be used to model - # FormInputDefault. However, the UI requires preserving the previous - # value when switching between `VARIABLE` and `CONSTANT` types. This - # necessitates retaining all fields, making a discriminated union unsuitable. - - type: PlaceholderType - - # The selector of default variable, used when `type` is `VARIABLE`. - selector: Sequence[str] = Field(default_factory=tuple) # - - # The value of the default, used when `type` is `CONSTANT`. - # TODO: How should we express JSON values? 
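- # For now, constant defaults are stored as plain strings.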
-    value: str = ""
-
-    @model_validator(mode="after")
-    def _validate_selector(self) -> Self:
-        if self.type == PlaceholderType.CONSTANT:
-            return self
-        if len(self.selector) < SELECTORS_LENGTH:
-            raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
-        return self
-
-
-class FormInput(BaseModel):
-    """Form input definition."""
-
-    type: FormInputType
-    output_variable_name: str
-    default: FormInputDefault | None = None
-
-
-_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
-
-
-class UserAction(BaseModel):
-    """User action configuration."""
-
-    # id is the identifier for this action.
-    # It also serves as the identifier of the output handle.
-    #
-    # The id must be a valid identifier (it must satisfy the _IDENTIFIER_PATTERN above).
-    id: str = Field(max_length=20)
-    title: str = Field(max_length=20)
-    button_style: ButtonStyle = ButtonStyle.DEFAULT
-
-    @field_validator("id")
-    @classmethod
-    def _validate_id(cls, value: str) -> str:
-        if not _IDENTIFIER_PATTERN.match(value):
-            raise ValueError(
-                f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
-                f"and contain only letters, numbers, or underscores."
-            )
-        return value
 
 
 class HumanInputNodeData(BaseNodeData):
-    """Human Input node data."""
+    """Configuration schema for the HumanInput node."""
 
-    delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
-    form_content: str = ""
-    inputs: list[FormInput] = Field(default_factory=list)
-    user_actions: list[UserAction] = Field(default_factory=list)
-    timeout: int = 36
-    timeout_unit: TimeoutUnit = TimeoutUnit.HOUR
-
-    @field_validator("inputs")
-    @classmethod
-    def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
-        seen_names: set[str] = set()
-        for form_input in inputs:
-            name = form_input.output_variable_name
-            if name in seen_names:
-                raise ValueError(f"duplicated output_variable_name '{name}' in inputs")
-            seen_names.add(name)
-        return inputs
-
-    @field_validator("user_actions")
-    @classmethod
-    def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
-        seen_ids: set[str] = set()
-        for action in user_actions:
-            action_id = action.id
-            if action_id in seen_ids:
-                raise ValueError(f"duplicated user action id '{action_id}'")
-            seen_ids.add(action_id)
-        return user_actions
-
-    def is_webapp_enabled(self) -> bool:
-        for dm in self.delivery_methods:
-            if not dm.enabled:
-                continue
-            if dm.type == DeliveryMethodType.WEBAPP:
-                return True
-        return False
-
-    def expiration_time(self, start_time: datetime) -> datetime:
-        if self.timeout_unit == TimeoutUnit.HOUR:
-            return start_time + timedelta(hours=self.timeout)
-        elif self.timeout_unit == TimeoutUnit.DAY:
-            return start_time + timedelta(days=self.timeout)
-        else:
-            raise AssertionError("unknown timeout unit.")
-
-    def outputs_field_names(self) -> Sequence[str]:
-        field_names = []
-        for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content):
-            field_names.append(match.group("field_name"))
-        return field_names
-
-    def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
-        variable_mappings: dict[str, Sequence[str]] = {}
-
-        def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None:
-            for selector in selectors:
-                if len(selector) < SELECTORS_LENGTH:
-                    continue
-                qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#"
-                variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH])
-
-        form_template_parser = VariableTemplateParser(template=self.form_content)
-        _add_variable_selectors(
-            [selector.value_selector for selector in form_template_parser.extract_variable_selectors()]
-        )
-        for delivery_method in self.delivery_methods:
-            if not delivery_method.enabled:
-                continue
-            _add_variable_selectors(delivery_method.extract_variable_selectors())
-
-        for input in self.inputs:
-            default_value = input.default
-            if default_value is None:
-                continue
-            if default_value.type == PlaceholderType.CONSTANT:
-                continue
-            default_value_key = ".".join(default_value.selector)
-            qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#"
-            variable_mappings[qualified_variable_mapping_key] = default_value.selector
-
-        return variable_mappings
-
-    def find_action_text(self, action_id: str) -> str:
-        """
-        Resolve action display text by id.
-        """
-        for action in self.user_actions:
-            if action.id == action_id:
-                return action.title
-        return action_id
-
-
-class FormDefinition(BaseModel):
-    form_content: str
-    inputs: list[FormInput] = Field(default_factory=list)
-    user_actions: list[UserAction] = Field(default_factory=list)
-    rendered_content: str
-    expiration_time: datetime
-
-    # this is used to store the resolved default values
-    default_values: dict[str, Any] = Field(default_factory=dict)
-
-    # node_title records the title of the HumanInput node.
-    node_title: str | None = None
-
-    # display_in_ui controls whether the form should be displayed in UI surfaces.
-    display_in_ui: bool | None = None
-
-
-class HumanInputSubmissionValidationError(ValueError):
-    pass
-
-
-def validate_human_input_submission(
-    *,
-    inputs: Sequence[FormInput],
-    user_actions: Sequence[UserAction],
-    selected_action_id: str,
-    form_data: Mapping[str, Any],
-) -> None:
-    available_actions = {action.id for action in user_actions}
-    if selected_action_id not in available_actions:
-        raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")
-
-    provided_inputs = set(form_data.keys())
-    missing_inputs = [
-        form_input.output_variable_name
-        for form_input in inputs
-        if form_input.output_variable_name not in provided_inputs
-    ]
-
-    if missing_inputs:
-        missing_list = ", ".join(missing_inputs)
-        raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")
+    required_variables: list[str] = Field(default_factory=list)
+    pause_reason: str | None = Field(default=None)
diff --git a/api/core/workflow/nodes/human_input/enums.py b/api/core/workflow/nodes/human_input/enums.py
deleted file mode 100644
index da85728828..0000000000
--- a/api/core/workflow/nodes/human_input/enums.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import enum
-
-
-class HumanInputFormStatus(enum.StrEnum):
-    """Status of a human input form."""
-
-    # Awaiting submission from any recipient. Forms stay in this state until
-    # submitted or a timeout rule applies.
-    WAITING = enum.auto()
-    # Global timeout reached. The workflow run is stopped and will not resume.
-    # This is distinct from node-level timeout.
-    EXPIRED = enum.auto()
-    # Submitted by a recipient; form data is available and execution resumes
-    # along the selected action edge.
-    SUBMITTED = enum.auto()
-    # Node-level timeout reached. The human input node should emit a timeout
-    # event and the workflow should resume along the timeout edge.
-    TIMEOUT = enum.auto()
-
-
-class HumanInputFormKind(enum.StrEnum):
-    """Kind of a human input form."""
-
-    RUNTIME = enum.auto()  # Form created during workflow execution.
-    DELIVERY_TEST = enum.auto()  # Form created for delivery tests.
-
-
-class DeliveryMethodType(enum.StrEnum):
-    """Delivery method types for human input forms."""
-
-    # WEBAPP controls whether the form is delivered to the web app. It controls not only
-    # the standalone web app, but also the installed apps in the console.
-    WEBAPP = enum.auto()
-
-    EMAIL = enum.auto()
-
-
-class ButtonStyle(enum.StrEnum):
-    """Button styles for user actions."""
-
-    PRIMARY = enum.auto()
-    DEFAULT = enum.auto()
-    ACCENT = enum.auto()
-    GHOST = enum.auto()
-
-
-class TimeoutUnit(enum.StrEnum):
-    """Timeout unit for form expiration."""
-
-    HOUR = enum.auto()
-    DAY = enum.auto()
-
-
-class FormInputType(enum.StrEnum):
-    """Form input types."""
-
-    TEXT_INPUT = enum.auto()
-    PARAGRAPH = enum.auto()
-
-
-class PlaceholderType(enum.StrEnum):
-    """Default value types for form inputs."""
-
-    VARIABLE = enum.auto()
-    CONSTANT = enum.auto()
-
-
-class EmailRecipientType(enum.StrEnum):
-    """Email recipient types."""
-
-    MEMBER = enum.auto()
-    EXTERNAL = enum.auto()
diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py
index 1d7522ea25..6c8bf36fab 100644
--- a/api/core/workflow/nodes/human_input/human_input_node.py
+++ b/api/core/workflow/nodes/human_input/human_input_node.py
@@ -1,42 +1,12 @@
-import json
-import logging
-from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any
+from collections.abc import Mapping
+from typing import Any
 
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
 from core.workflow.entities.pause_reason import HumanInputRequired
 from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
-from core.workflow.node_events import (
-    HumanInputFormFilledEvent,
-    HumanInputFormTimeoutEvent,
-    NodeRunResult,
-    PauseRequestedEvent,
-)
-from core.workflow.node_events.base import NodeEventBase
-from core.workflow.node_events.node import StreamCompletedEvent
+from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
 from core.workflow.nodes.base.node import Node
-from core.workflow.repositories.human_input_form_repository import (
-    FormCreateParams,
-    HumanInputFormEntity,
-    HumanInputFormRepository,
-)
-from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
-from extensions.ext_database import db
-from libs.datetime_utils import naive_utc_now
-
-from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
-from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType
-
-if TYPE_CHECKING:
-    from core.workflow.entities.graph_init_params import GraphInitParams
-    from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
-
-
-_SELECTED_BRANCH_KEY = "selected_branch"
-
-
-logger = logging.getLogger(__name__)
+from .entities import HumanInputNodeData
 
 
 class HumanInputNode(Node[HumanInputNodeData]):
@@ -47,7 +17,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
         "edge_source_handle",
         "edgeSourceHandle",
         "source_handle",
-        _SELECTED_BRANCH_KEY,
+        "selected_branch",
         "selectedBranch",
         "branch",
         "branch_id",
@@ -55,37 +25,43 @@ class HumanInputNode(Node[HumanInputNodeData]):
         "handle",
     )
 
-    _node_data: HumanInputNodeData
-    _form_repository: HumanInputFormRepository
-    _OUTPUT_FIELD_ACTION_ID = "__action_id"
-    _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content"
-    _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout"
-
-    def __init__(
-        self,
-        id: str,
-        config: Mapping[str, Any],
-        graph_init_params: "GraphInitParams",
-        graph_runtime_state: "GraphRuntimeState",
-        form_repository: HumanInputFormRepository | None = None,
-    ) -> None:
-        super().__init__(
-            id=id,
-            config=config,
-            graph_init_params=graph_init_params,
-            graph_runtime_state=graph_runtime_state,
-        )
-        if form_repository is None:
-            form_repository = HumanInputFormRepositoryImpl(
-                session_factory=db.engine,
-                tenant_id=self.tenant_id,
-            )
-        self._form_repository = form_repository
-
     @classmethod
     def version(cls) -> str:
         return "1"
 
+    def _run(self):  # type: ignore[override]
+        if self._is_completion_ready():
+            branch_handle = self._resolve_branch_selection()
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                outputs={},
+                edge_source_handle=branch_handle or "source",
+            )
+
+        return self._pause_generator()
+
+    def _pause_generator(self):
+        # TODO(QuantumGhost): yield a real form id.
+        yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
+
+    def _is_completion_ready(self) -> bool:
+        """Determine whether all required inputs are satisfied."""
+
+        if not self.node_data.required_variables:
+            return False
+
+        variable_pool = self.graph_runtime_state.variable_pool
+
+        for selector_str in self.node_data.required_variables:
+            parts = selector_str.split(".")
+            if len(parts) != 2:
+                return False
+            segment = variable_pool.get(parts)
+            if segment is None:
+                return False
+
+        return True
+
     def _resolve_branch_selection(self) -> str | None:
         """Determine the branch handle selected by human input if available."""
 
@@ -132,224 +108,3 @@ class HumanInputNode(Node[HumanInputNodeData]):
                 return candidate
 
         return None
-
-    @property
-    def _workflow_execution_id(self) -> str:
-        workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
-        assert workflow_exec_id is not None
-        return workflow_exec_id
-
-    def _form_to_pause_event(self, form_entity: HumanInputFormEntity):
-        required_event = self._human_input_required_event(form_entity)
-        pause_requested_event = PauseRequestedEvent(reason=required_event)
-        return pause_requested_event
-
-    def resolve_default_values(self) -> Mapping[str, Any]:
-        variable_pool = self.graph_runtime_state.variable_pool
-        resolved_defaults: dict[str, Any] = {}
-        for input in self._node_data.inputs:
-            if (default_value := input.default) is None:
-                continue
-            if default_value.type == PlaceholderType.CONSTANT:
-                continue
-            resolved_value = variable_pool.get(default_value.selector)
-            if resolved_value is None:
-                # TODO: How should we handle this?
-                continue
-            resolved_defaults[input.output_variable_name] = (
-                WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value)
-            )
-
-        return resolved_defaults
-
-    def _should_require_console_recipient(self) -> bool:
-        if self.invoke_from == InvokeFrom.DEBUGGER:
-            return True
-        if self.invoke_from == InvokeFrom.EXPLORE:
-            return self._node_data.is_webapp_enabled()
-        return False
-
-    def _display_in_ui(self) -> bool:
-        if self.invoke_from == InvokeFrom.DEBUGGER:
-            return True
-        return self._node_data.is_webapp_enabled()
-
-    def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
-        enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
-        if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
-            enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
-        return [
-            apply_debug_email_recipient(
-                method,
-                enabled=self.invoke_from == InvokeFrom.DEBUGGER,
-                user_id=self.user_id or "",
-            )
-            for method in enabled_methods
-        ]
-
-    def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
-        node_data = self._node_data
-        resolved_default_values = self.resolve_default_values()
-        display_in_ui = self._display_in_ui()
-        form_token = form_entity.web_app_token
-        if display_in_ui and form_token is None:
-            raise AssertionError("Form token should be available for UI execution.")
-        return HumanInputRequired(
-            form_id=form_entity.id,
-            form_content=form_entity.rendered_content,
-            inputs=node_data.inputs,
-            actions=node_data.user_actions,
-            display_in_ui=display_in_ui,
-            node_id=self.id,
-            node_title=node_data.title,
-            form_token=form_token,
-            resolved_default_values=resolved_default_values,
-        )
-
-    def _run(self) -> Generator[NodeEventBase, None, None]:
-        """
-        Execute the human input node.
-
-        This method will:
-        1. Generate a unique form ID
-        2. Create form content with variable substitution
-        3. Create form in database
-        4. Send form via configured delivery methods
-        5. Suspend workflow execution
-        6. Wait for form submission to resume
-        """
-        repo = self._form_repository
-        form = repo.get_form(self._workflow_execution_id, self.id)
-        if form is None:
-            display_in_ui = self._display_in_ui()
-            params = FormCreateParams(
-                app_id=self.app_id,
-                workflow_execution_id=self._workflow_execution_id,
-                node_id=self.id,
-                form_config=self._node_data,
-                rendered_content=self.render_form_content_before_submission(),
-                delivery_methods=self._effective_delivery_methods(),
-                display_in_ui=display_in_ui,
-                resolved_default_values=self.resolve_default_values(),
-                console_recipient_required=self._should_require_console_recipient(),
-                console_creator_account_id=(
-                    self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
-                ),
-                backstage_recipient_required=True,
-            )
-            form_entity = self._form_repository.create_form(params)
-            # Create human input required event
-
-            logger.info(
-                "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
-                self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
-                self.id,
-                form_entity.id,
-            )
-            yield self._form_to_pause_event(form_entity)
-            return
-
-        if (
-            form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}
-            or form.expiration_time <= naive_utc_now()
-        ):
-            yield HumanInputFormTimeoutEvent(
-                node_title=self._node_data.title,
-                expiration_time=form.expiration_time,
-            )
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    outputs={self._OUTPUT_FIELD_ACTION_ID: ""},
-                    edge_source_handle=self._TIMEOUT_HANDLE,
-                )
-            )
-            return
-
-        if not form.submitted:
-            yield self._form_to_pause_event(form)
-            return
-
-        selected_action_id = form.selected_action_id
-        if selected_action_id is None:
-            raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
-        submitted_data = form.submitted_data or {}
-        outputs: dict[str, Any] = dict(submitted_data)
-        outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id
-        rendered_content = self.render_form_content_with_outputs(
-            form.rendered_content,
-            outputs,
-            self._node_data.outputs_field_names(),
-        )
-        outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content
-
-        action_text = self._node_data.find_action_text(selected_action_id)
-
-        yield HumanInputFormFilledEvent(
-            node_title=self._node_data.title,
-            rendered_content=rendered_content,
-            action_id=selected_action_id,
-            action_text=action_text,
-        )
-
-        yield StreamCompletedEvent(
-            node_run_result=NodeRunResult(
-                status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                outputs=outputs,
-                edge_source_handle=selected_action_id,
-            )
-        )
-
-    def render_form_content_before_submission(self) -> str:
-        """
-        Process form content by substituting variables.
-
-        This method should:
-        1. Parse the form_content markdown
-        2. Substitute {{#node_name.var_name#}} with actual values
-        3. Keep {{#$output.field_name#}} placeholders for form inputs
-        """
-        rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
-            self._node_data.form_content,
-        )
-        return rendered_form_content.markdown
-
-    @staticmethod
-    def render_form_content_with_outputs(
-        form_content: str,
-        outputs: Mapping[str, Any],
-        field_names: Sequence[str],
-    ) -> str:
-        """
-        Replace {{#$output.xxx#}} placeholders with submitted values.
-        """
-        rendered_content = form_content
-        for field_name in field_names:
-            placeholder = "{{#$output." + field_name + "#}}"
-            value = outputs.get(field_name)
-            if value is None:
-                replacement = ""
-            elif isinstance(value, (dict, list)):
-                replacement = json.dumps(value, ensure_ascii=False)
-            else:
-                replacement = str(value)
-            rendered_content = rendered_content.replace(placeholder, replacement)
-        return rendered_content
-
-    @classmethod
-    def _extract_variable_selector_to_variable_mapping(
-        cls,
-        *,
-        graph_config: Mapping[str, Any],
-        node_id: str,
-        node_data: Mapping[str, Any],
-    ) -> Mapping[str, Sequence[str]]:
-        """
-        Extract variable selectors referenced in form content and input default values.
-
-        This method should parse:
-        1. Variables referenced in form_content ({{#node_name.var_name#}})
-        2. Variables referenced in input default values
-        """
-        validated_node_data = HumanInputNodeData.model_validate(node_data)
-        return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
diff --git a/api/core/workflow/repositories/human_input_form_repository.py b/api/core/workflow/repositories/human_input_form_repository.py
deleted file mode 100644
index efde59c6fd..0000000000
--- a/api/core/workflow/repositories/human_input_form_repository.py
+++ /dev/null
@@ -1,152 +0,0 @@
-import abc
-import dataclasses
-from collections.abc import Mapping, Sequence
-from datetime import datetime
-from typing import Any, Protocol
-
-from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData
-from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
-
-
-class HumanInputError(Exception):
-    pass
-
-
-class FormNotFoundError(HumanInputError):
-    pass
-
-
-@dataclasses.dataclass
-class FormCreateParams:
-    # app_id is the identifier for the app that the form belongs to.
-    # It is a string with uuid format.
-    app_id: str
-    # None when creating a delivery test form; set for runtime forms.
-    workflow_execution_id: str | None
-
-    # node_id is the identifier for a specific
-    # node in the graph.
-    #
-    # TODO: for node inside loop / iteration, this would
-    # cause problems, as a single node may be executed multiple times.
-    node_id: str
-
-    form_config: HumanInputNodeData
-    rendered_content: str
-    # Delivery methods already filtered by runtime context (invoke_from).
-    delivery_methods: Sequence[DeliveryChannelConfig]
-    # UI display flag computed by runtime context.
-    display_in_ui: bool
-
-    # resolved_default_values saves the values for defaults with
-    # type = VARIABLE.
-    #
-    # For type = CONSTANT, the value is not stored inside `resolved_default_values`
-    resolved_default_values: Mapping[str, Any]
-    form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
-
-    # Force creating a console-only recipient for submission in Console.
-    console_recipient_required: bool = False
-    console_creator_account_id: str | None = None
-    # Force creating a backstage recipient for submission in Console.
-    backstage_recipient_required: bool = False
-
-
-class HumanInputFormEntity(abc.ABC):
-    @property
-    @abc.abstractmethod
-    def id(self) -> str:
-        """id returns the identifier of the form."""
-        pass
-
-    @property
-    @abc.abstractmethod
-    def web_app_token(self) -> str | None:
-        """web_app_token returns the token for submission inside the webapp.
-
-        For console/debug execution, this may point to the console submission token
-        if the form is configured to require console delivery.
-        """
-
-        # TODO: what if the users are allowed to add multiple
-        # webapp delivery?
-        pass
-
-    @property
-    @abc.abstractmethod
-    def recipients(self) -> list["HumanInputFormRecipientEntity"]: ...
-
-    @property
-    @abc.abstractmethod
-    def rendered_content(self) -> str:
-        """Rendered markdown content associated with the form."""
-        ...
-
-    @property
-    @abc.abstractmethod
-    def selected_action_id(self) -> str | None:
-        """Identifier of the selected user action if the form has been submitted."""
-        ...
-
-    @property
-    @abc.abstractmethod
-    def submitted_data(self) -> Mapping[str, Any] | None:
-        """Submitted form data if available."""
-        ...
-
-    @property
-    @abc.abstractmethod
-    def submitted(self) -> bool:
-        """Whether the form has been submitted."""
-        ...
-
-    @property
-    @abc.abstractmethod
-    def status(self) -> HumanInputFormStatus:
-        """Current status of the form."""
-        ...
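For reference, the `_is_completion_ready` contract added to `human_input_node.py` above can be exercised in isolation. The sketch below is illustrative only: it swaps the real `VariablePool` for a plain dict keyed by `(node_id, variable_name)` tuples, but keeps the two behaviors visible in the hunk — an empty `required_variables` list always pauses, and every selector must have the shape `node_id.variable_name` and resolve to a value.

# Hypothetical stand-in for VariablePool: maps (node_id, variable_name) to a value.
Pool = dict[tuple[str, str], object]

def is_completion_ready(required_variables: list[str], pool: Pool) -> bool:
    if not required_variables:
        # No required variables configured: the node always pauses first.
        return False
    for selector_str in required_variables:
        parts = selector_str.split(".")
        if len(parts) != 2:
            # Malformed selector; treat the node as not ready.
            return False
        if pool.get((parts[0], parts[1])) is None:
            return False
    return True

assert is_completion_ready(["form.answer"], {("form", "answer"): "yes"})
assert not is_completion_ready(["form.answer"], {})
assert not is_completion_ready([], {})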
-
-    @property
-    @abc.abstractmethod
-    def expiration_time(self) -> datetime:
-        """When the form expires."""
-        ...
-
-
-class HumanInputFormRecipientEntity(abc.ABC):
-    @property
-    @abc.abstractmethod
-    def id(self) -> str:
-        """id returns the identifier of this recipient."""
-        ...
-
-    @property
-    @abc.abstractmethod
-    def token(self) -> str:
-        """token returns a random string used to submit the form"""
-        ...
-
-
-class HumanInputFormRepository(Protocol):
-    """
-    Repository interface for HumanInputForm.
-
-    This interface defines the contract for accessing and manipulating
-    HumanInputForm data, regardless of the underlying storage mechanism.
-
-    Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
-    and other implementation details should be handled at the implementation level, not in
-    the core interface. This keeps the core domain model clean and independent of specific
-    application domains or deployment scenarios.
-    """
-
-    def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
-        """Get the form created for a given human input node in a workflow execution. Returns
-        `None` if the form has not been created yet."""
-        ...
-
-    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
-        """
-        Create a human input form from form definition.
-        """
-        ...
diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py
index f79230217c..401cecc162 100644
--- a/api/core/workflow/runtime/graph_runtime_state.py
+++ b/api/core/workflow/runtime/graph_runtime_state.py
@@ -6,18 +6,14 @@ import threading
 from collections.abc import Mapping, Sequence
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Protocol
+from typing import Any, Protocol
 
-from pydantic import BaseModel, Field
 from pydantic.json import pydantic_encoder
 
 from core.model_runtime.entities.llm_entities import LLMUsage
-from core.workflow.enums import NodeState
+from core.workflow.entities.pause_reason import PauseReason
 from core.workflow.runtime.variable_pool import VariablePool
 
-if TYPE_CHECKING:
-    from core.workflow.entities.pause_reason import PauseReason
-
 
 class ReadyQueueProtocol(Protocol):
     """Structural interface required from ready queue implementations."""
@@ -64,7 +60,7 @@ class GraphExecutionProtocol(Protocol):
     aborted: bool
     error: Exception | None
     exceptions_count: int
-    pause_reasons: Sequence[PauseReason]
+    pause_reasons: list[PauseReason]
 
     def start(self) -> None:
        """Transition execution into the running state."""
@@ -107,33 +103,14 @@ class ResponseStreamCoordinatorProtocol(Protocol):
         ...
 
 
-class NodeProtocol(Protocol):
-    """Structural interface for graph nodes."""
-
-    id: str
-    state: NodeState
-
-
-class EdgeProtocol(Protocol):
-    id: str
-    state: NodeState
-
-
 class GraphProtocol(Protocol):
     """Structural interface required from graph instances attached to the runtime state."""
 
-    nodes: Mapping[str, NodeProtocol]
-    edges: Mapping[str, EdgeProtocol]
-    root_node: NodeProtocol
+    nodes: Mapping[str, object]
+    edges: Mapping[str, object]
+    root_node: object
 
-    def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
-
-
-class _GraphStateSnapshot(BaseModel):
-    """Serializable graph state snapshot for node/edge states."""
-
-    nodes: dict[str, NodeState] = Field(default_factory=dict)
-    edges: dict[str, NodeState] = Field(default_factory=dict)
+    def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
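The `Protocol` classes in this hunk are structural: anything with a matching attribute shape satisfies them, with no inheritance involved. A minimal sketch of why the loosened `object`-typed members still type-check (the `GraphLike` and `StubGraph` names are hypothetical, not classes from the codebase):

from collections.abc import Mapping, Sequence
from typing import Protocol


class GraphLike(Protocol):
    # Same shape as GraphProtocol above, reproduced so the sketch is self-contained.
    nodes: Mapping[str, object]
    edges: Mapping[str, object]
    root_node: object

    def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...


class StubGraph:
    """Declares no relationship to GraphLike; the shape alone suffices."""

    def __init__(self) -> None:
        self.root_node: object = object()
        self.nodes: Mapping[str, object] = {"start": self.root_node}
        self.edges: Mapping[str, object] = {}

    def get_outgoing_edges(self, node_id: str) -> Sequence[object]:
        return []


def attach(graph: GraphLike) -> None:
    # A static checker accepts StubGraph here purely by structural match.
    assert graph.get_outgoing_edges("start") == []


attach(StubGraph())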
 
 
 @dataclass(slots=True)
@@ -151,20 +128,10 @@ class _GraphRuntimeStateSnapshot:
     graph_execution_dump: str | None
     response_coordinator_dump: str | None
     paused_nodes: tuple[str, ...]
-    deferred_nodes: tuple[str, ...]
-    graph_node_states: dict[str, NodeState]
-    graph_edge_states: dict[str, NodeState]
 
 
 class GraphRuntimeState:
-    """Mutable runtime state shared across graph execution components.
-
-    `GraphRuntimeState` encapsulates the runtime state of workflow execution,
-    including scheduling details, variable values, and timing information.
-
-    Values that are initialized prior to workflow execution and remain constant
-    throughout the execution should be part of `GraphInitParams` instead.
-    """
+    """Mutable runtime state shared across graph execution components."""
 
     def __init__(
         self,
@@ -202,16 +169,6 @@ class GraphRuntimeState:
         self._pending_response_coordinator_dump: str | None = None
         self._pending_graph_execution_workflow_id: str | None = None
         self._paused_nodes: set[str] = set()
-        self._deferred_nodes: set[str] = set()
-
-        # Node and edges states needed to be restored into
-        # graph object.
-        #
-        # These two fields are non-None only when resuming from a snapshot.
-        # Once the graph is attached, these two fields will be set to None.
-        self._pending_graph_node_states: dict[str, NodeState] | None = None
-        self._pending_graph_edge_states: dict[str, NodeState] | None = None
-
         self.stop_event: threading.Event = threading.Event()
 
         if graph is not None:
@@ -233,7 +190,6 @@ class GraphRuntimeState:
         if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
             self._response_coordinator.loads(self._pending_response_coordinator_dump)
             self._pending_response_coordinator_dump = None
-        self._apply_pending_graph_state()
 
     def configure(self, *, graph: GraphProtocol | None = None) -> None:
         """Ensure core collaborators are initialized with the provided context."""
@@ -355,13 +311,8 @@ class GraphRuntimeState:
             "ready_queue": self.ready_queue.dumps(),
             "graph_execution": self.graph_execution.dumps(),
             "paused_nodes": list(self._paused_nodes),
-            "deferred_nodes": list(self._deferred_nodes),
         }
 
-        graph_state = self._snapshot_graph_state()
-        if graph_state is not None:
-            snapshot["graph_state"] = graph_state
-
         if self._response_coordinator is not None and self._graph is not None:
             snapshot["response_coordinator"] = self._response_coordinator.dumps()
 
@@ -395,11 +346,6 @@ class GraphRuntimeState:
 
         self._paused_nodes.add(node_id)
 
-    def get_paused_nodes(self) -> list[str]:
-        """Retrieve the list of paused nodes without mutating internal state."""
-
-        return list(self._paused_nodes)
-
     def consume_paused_nodes(self) -> list[str]:
         """Retrieve and clear the list of paused nodes awaiting resume."""
 
@@ -407,23 +353,6 @@ class GraphRuntimeState:
         self._paused_nodes.clear()
         return nodes
 
-    def register_deferred_node(self, node_id: str) -> None:
-        """Record a node that became ready during pause and should resume later."""
-
-        self._deferred_nodes.add(node_id)
-
-    def get_deferred_nodes(self) -> list[str]:
-        """Retrieve deferred nodes without mutating internal state."""
-
-        return list(self._deferred_nodes)
-
-    def consume_deferred_nodes(self) -> list[str]:
-        """Retrieve and clear deferred nodes awaiting resume."""
-
-        nodes = list(self._deferred_nodes)
-        self._deferred_nodes.clear()
-        return nodes
-
     # ------------------------------------------------------------------
     # Builders
     # ------------------------------------------------------------------
@@ -485,10 +414,6 @@ class GraphRuntimeState:
         graph_execution_payload = payload.get("graph_execution")
         response_payload = payload.get("response_coordinator")
         paused_nodes_payload = payload.get("paused_nodes", [])
-        deferred_nodes_payload = payload.get("deferred_nodes", [])
-        graph_state_payload = payload.get("graph_state", {}) or {}
-        graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes")
-        graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges")
 
         return _GraphRuntimeStateSnapshot(
             start_at=start_at,
@@ -502,9 +427,6 @@
             graph_execution_dump=graph_execution_payload,
             response_coordinator_dump=response_payload,
             paused_nodes=tuple(map(str, paused_nodes_payload)),
-            deferred_nodes=tuple(map(str, deferred_nodes_payload)),
-            graph_node_states=graph_node_states,
-            graph_edge_states=graph_edge_states,
         )
 
     def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
@@ -520,10 +442,6 @@
         self._restore_graph_execution(snapshot.graph_execution_dump)
         self._restore_response_coordinator(snapshot.response_coordinator_dump)
         self._paused_nodes = set(snapshot.paused_nodes)
-        self._deferred_nodes = set(snapshot.deferred_nodes)
-        self._pending_graph_node_states = snapshot.graph_node_states or None
-        self._pending_graph_edge_states = snapshot.graph_edge_states or None
-        self._apply_pending_graph_state()
 
     def _restore_ready_queue(self, payload: str | None) -> None:
         if payload is not None:
@@ -560,68 +478,3 @@
         self._pending_response_coordinator_dump = payload
         self._response_coordinator = None
-
-    def _snapshot_graph_state(self) -> _GraphStateSnapshot:
-        graph = self._graph
-        if graph is None:
-            if self._pending_graph_node_states is None and self._pending_graph_edge_states is None:
-                return _GraphStateSnapshot()
-            return _GraphStateSnapshot(
-                nodes=self._pending_graph_node_states or {},
-                edges=self._pending_graph_edge_states or {},
-            )
-
-        nodes = graph.nodes
-        edges = graph.edges
-        if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping):
-            return _GraphStateSnapshot()
-
-        node_states = {}
-        for node_id, node in nodes.items():
-            if not isinstance(node_id, str):
-                continue
-            node_states[node_id] = node.state
-
-        edge_states = {}
-        for edge_id, edge in edges.items():
-            if not isinstance(edge_id, str):
-                continue
-            edge_states[edge_id] = edge.state
-
-        return _GraphStateSnapshot(nodes=node_states, edges=edge_states)
-
-    def _apply_pending_graph_state(self) -> None:
-        if self._graph is None:
-            return
-        if self._pending_graph_node_states:
-            for node_id, state in self._pending_graph_node_states.items():
-                node = self._graph.nodes.get(node_id)
-                if node is None:
-                    continue
-                node.state = state
-        if self._pending_graph_edge_states:
-            for edge_id, state in self._pending_graph_edge_states.items():
-                edge = self._graph.edges.get(edge_id)
-                if edge is None:
-                    continue
-                edge.state = state
-
-        self._pending_graph_node_states = None
-        self._pending_graph_edge_states = None
-
-
-def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]:
-    if not isinstance(payload, Mapping):
-        return {}
-    raw_map = payload.get(key, {})
-    if not isinstance(raw_map, Mapping):
-        return {}
-    result: dict[str, NodeState] = {}
-    for node_id, raw_state in raw_map.items():
-        if not isinstance(node_id, str):
-            continue
-        try:
-            result[node_id] = NodeState(str(raw_state))
-        except ValueError:
-            continue
-    return result
diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py
index f1f549e1f8..5456043ccd 100644
--- a/api/core/workflow/workflow_type_encoder.py
+++ b/api/core/workflow/workflow_type_encoder.py
@@ -15,14 +15,12 @@ class WorkflowRuntimeTypeConverter:
     def to_json_encodable(self, value: None) -> None: ...
 
     def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
-        """Convert runtime values to JSON-serializable structures."""
-
-        result = self.value_to_json_encodable_recursive(value)
+        result = self._to_json_encodable_recursive(value)
         if isinstance(result, Mapping) or result is None:
             return result
         return {}
 
-    def value_to_json_encodable_recursive(self, value: Any):
+    def _to_json_encodable_recursive(self, value: Any):
         if value is None:
             return value
         if isinstance(value, (bool, int, str, float)):
@@ -31,7 +29,7 @@ class WorkflowRuntimeTypeConverter:
             # Convert Decimal to float for JSON serialization
             return float(value)
         if isinstance(value, Segment):
-            return self.value_to_json_encodable_recursive(value.value)
+            return self._to_json_encodable_recursive(value.value)
         if isinstance(value, File):
             return value.to_dict()
         if isinstance(value, BaseModel):
@@ -39,11 +37,11 @@ class WorkflowRuntimeTypeConverter:
         if isinstance(value, dict):
             res = {}
             for k, v in value.items():
-                res[k] = self.value_to_json_encodable_recursive(v)
+                res[k] = self._to_json_encodable_recursive(v)
             return res
         if isinstance(value, list):
             res_list = []
             for item in value:
-                res_list.append(self.value_to_json_encodable_recursive(item))
+                res_list.append(self._to_json_encodable_recursive(item))
             return res_list
         return value
diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh
index 03e6cbda68..c0279f893b 100755
--- a/api/docker/entrypoint.sh
+++ b/api/docker/entrypoint.sh
@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
   if [[ -z "${CELERY_QUEUES}" ]]; then
     if [[ "${EDITION}" == "CLOUD" ]]; then
       # Cloud edition: separate queues for dataset and trigger tasks
-      DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
+      DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
     else
       # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
-      DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
+      DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
     fi
   else
     DEFAULT_QUEUES="${CELERY_QUEUES}"
@@ -102,7 +102,7 @@ elif [[ "${MODE}" == "job" ]]; then
   fi
 
   echo "Running Flask job command: flask $*"
-  
+
   # Temporarily disable exit on error to capture exit code
   set +e
   flask "$@"
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index aa9723f375..af983f6d87 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -151,12 +151,6 @@ def init_app(app: DifyApp) -> Celery:
             "task": "schedule.queue_monitor_task.queue_monitor_task",
             "schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
         }
-    if dify_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK:
-        imports.append("tasks.human_input_timeout_tasks")
-        beat_schedule["human_input_form_timeout"] = {
-            "task": "human_input_form_timeout.check_and_resume",
-            "schedule": timedelta(minutes=dify_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL),
-        }
     if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
         imports.append("schedule.check_upgradable_plugin_task")
         imports.append("tasks.process_tenant_plugin_autoupgrade_check_task")
diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py
index 0797a3cb98..5e75bc36b0 100644
--- a/api/extensions/ext_redis.py
+++ b/api/extensions/ext_redis.py
@@ -8,16 +8,12 @@ from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
 
 import redis
 from redis import RedisError
 from redis.cache import CacheConfig
-from redis.client import PubSub
 from redis.cluster import ClusterNode, RedisCluster
 from redis.connection import Connection, SSLConnection
 from redis.sentinel import Sentinel
 
 from configs import dify_config
 from dify_app import DifyApp
-from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
-from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
-from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
 
 if TYPE_CHECKING:
     from redis.lock import Lock
@@ -110,7 +106,6 @@ class RedisClientWrapper:
     def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
     def zcard(self, name: str | bytes) -> Any: ...
     def getdel(self, name: str | bytes) -> Any: ...
-    def pubsub(self) -> PubSub: ...
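The `pubsub()` declaration removed here exposed redis-py's standard Pub/Sub handle through the wrapper. For orientation, this is the stock redis-py subscribe-and-poll pattern that callers were forwarded to; a minimal sketch assuming a Redis instance on localhost:

import redis

client = redis.Redis(host="localhost", port=6379)
pubsub = client.pubsub()
pubsub.subscribe("events")

# Poll briefly instead of blocking; subscribe confirmations are filtered
# out so only real payloads are returned.
message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
if message is not None:
    print(message["channel"], message["data"])

pubsub.close()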
def __getattr__(self, item: str) -> Any: if self._client is None: @@ -119,7 +114,6 @@ class RedisClientWrapper: redis_client: RedisClientWrapper = RedisClientWrapper() -pubsub_redis_client: RedisClientWrapper = RedisClientWrapper() def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: @@ -232,12 +226,6 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis return client -def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]: - if use_clusters: - return RedisCluster.from_url(pubsub_url) - return redis.Redis.from_url(pubsub_url) - - def init_app(app: DifyApp): """Initialize Redis client and attach it to the app.""" global redis_client @@ -256,24 +244,6 @@ def init_app(app: DifyApp): redis_client.initialize(client) app.extensions["redis"] = redis_client - pubsub_client = client - if dify_config.normalized_pubsub_redis_url: - pubsub_client = _create_pubsub_client( - dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS - ) - pubsub_redis_client.initialize(pubsub_client) - - -def get_pubsub_redis_client() -> RedisClientWrapper: - return pubsub_redis_client - - -def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol: - redis_conn = get_pubsub_redis_client() - if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded": - return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType] - return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType] - P = ParamSpec("P") R = TypeVar("R") diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 817c8b0448..f67723630b 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -13,7 +13,6 @@ from typing import Any from sqlalchemy.orm import sessionmaker -from core.workflow.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value @@ -208,10 +207,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep reverse=True, ) - for row in deduplicated_results: - model = _dict_to_workflow_node_execution_model(row) - if model.status != WorkflowNodeExecutionStatus.PAUSED: - return model + if deduplicated_results: + return _dict_to_workflow_node_execution_model(deduplicated_results[0]) return None @@ -312,8 +309,6 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep if model and model.id: # Ensure model is valid models.append(model) - models = [model for model in models if model.status != WorkflowNodeExecutionStatus.PAUSED] - # Sort by index DESC for trace visualization models.sort(key=lambda x: x.index, reverse=True) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index cda46f2339..d8ae0ad8b8 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -192,7 +192,6 @@ class StatusCount(ResponseModel): success: int failed: int partial_success: int - paused: int class ModelConfig(ResponseModel): diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 
77b26a7423..e6c3b42f93 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -6,7 +6,6 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, field_validator -from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from core.file import File from fields.conversation_fields import AgentThought, JSONValue, MessageFile @@ -62,7 +61,6 @@ class MessageListItem(ResponseModel): message_files: list[MessageFile] status: str error: str | None = None - extra_contents: list[ExecutionExtraContentDomainModel] @field_validator("inputs", mode="before") @classmethod diff --git a/api/libs/broadcast_channel/redis/_subscription.py b/api/libs/broadcast_channel/redis/_subscription.py index fa2be421a1..7d4b8e63ca 100644 --- a/api/libs/broadcast_channel/redis/_subscription.py +++ b/api/libs/broadcast_channel/redis/_subscription.py @@ -162,7 +162,7 @@ class RedisSubscriptionBase(Subscription): self._start_if_needed() return iter(self._message_iterator()) - def receive(self, timeout: float | None = 0.1) -> bytes | None: + def receive(self, timeout: float | None = None) -> bytes | None: """Receive the next message from the subscription.""" if self._closed.is_set(): raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed") diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index 9e8ab90e8e..d190c51bbc 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -61,14 +61,7 @@ class _RedisShardedSubscription(RedisSubscriptionBase): def _get_message(self) -> dict | None: assert self._pubsub is not None - # NOTE(QuantumGhost): this is an issue in - # upstream code. If Sharded PubSub is used with Cluster, the - # `ClusterPubSub.get_sharded_message` will return `None` regardless of - # message['type']. - # - # Since we have already filtered at the caller's site, we can safely set - # `ignore_subscribe_messages=False`. - return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined] + return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1) # type: ignore[attr-defined] def _get_message_type(self) -> str: return "smessage" diff --git a/api/libs/email_template_renderer.py b/api/libs/email_template_renderer.py deleted file mode 100644 index 98ea30ab46..0000000000 --- a/api/libs/email_template_renderer.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -Email template rendering helpers with configurable safety modes. 
-""" - -import time -from collections.abc import Mapping -from typing import Any - -from flask import render_template_string -from jinja2.runtime import Context -from jinja2.sandbox import ImmutableSandboxedEnvironment - -from configs import dify_config -from configs.feature import TemplateMode - - -class SandboxedEnvironment(ImmutableSandboxedEnvironment): - """Sandboxed environment with execution timeout.""" - - def __init__(self, timeout: int, *args: Any, **kwargs: Any): - self._deadline = time.time() + timeout if timeout else None - super().__init__(*args, **kwargs) - - def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any: - if self._deadline is not None and time.time() > self._deadline: - raise TimeoutError("Template rendering timeout") - return super().call(context, obj, *args, **kwargs) - - -def render_email_template(template: str, substitutions: Mapping[str, str]) -> str: - """ - Render email template content according to the configured template mode. - - In unsafe mode, Jinja expressions are evaluated directly. - In sandbox mode, a sandboxed environment with timeout is used. - In disabled mode, the template is returned without rendering. - """ - mode = dify_config.MAIL_TEMPLATING_MODE - timeout = dify_config.MAIL_TEMPLATING_TIMEOUT - - if mode == TemplateMode.UNSAFE: - return render_template_string(template, **substitutions) - if mode == TemplateMode.SANDBOX: - env = SandboxedEnvironment(timeout=timeout) - tmpl = env.from_string(template) - return tmpl.render(substitutions) - if mode == TemplateMode.DISABLED: - return template - raise ValueError(f"Unsupported mail templating mode: {mode}") diff --git a/api/libs/flask_utils.py b/api/libs/flask_utils.py index e45c8fe319..beade7eb25 100644 --- a/api/libs/flask_utils.py +++ b/api/libs/flask_utils.py @@ -1,15 +1,12 @@ import contextvars from collections.abc import Iterator from contextlib import contextmanager -from typing import TYPE_CHECKING, TypeVar +from typing import TypeVar from flask import Flask, g T = TypeVar("T") -if TYPE_CHECKING: - from models import Account, EndUser - @contextmanager def preserve_flask_contexts( @@ -67,7 +64,3 @@ def preserve_flask_contexts( finally: # Any cleanup can be added here if needed pass - - -def set_login_user(user: "Account | EndUser"): - g._login_user = user diff --git a/api/libs/helper.py b/api/libs/helper.py index fb577b9c99..07c4823727 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -7,10 +7,10 @@ import struct import subprocess import time import uuid -from collections.abc import Callable, Generator, Mapping +from collections.abc import Generator, Mapping from datetime import datetime from hashlib import sha256 -from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast +from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast from uuid import UUID from zoneinfo import available_timezones @@ -126,13 +126,6 @@ class TimestampField(fields.Raw): return int(value.timestamp()) -class OptionalTimestampField(fields.Raw): - def format(self, value) -> int | None: - if value is None: - return None - return int(value.timestamp()) - - def email(email): # Define a regex pattern for email addresses pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$" @@ -244,26 +237,6 @@ def convert_datetime_to_date(field, target_timezone: str = ":tz"): def generate_string(n): - """ - Generates a cryptographically secure random string of the specified length. 
- - This function uses a cryptographically secure pseudorandom number generator (CSPRNG) - to create a string composed of ASCII letters (both uppercase and lowercase) and digits. - - Each character in the generated string provides approximately 5.95 bits of entropy - (log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the - length of the string (`n`) should be at least 22 characters. - - Args: - n (int): The length of the random string to generate. For secure usage, - `n` should be 22 or greater. - - Returns: - str: A random string of length `n` composed of ASCII letters and digits. - - Note: - This function is suitable for generating credentials or other secure tokens. - """ letters_digits = string.ascii_letters + string.digits result = "" for _ in range(n): @@ -432,35 +405,11 @@ class TokenManager: return f"{token_type}:account:{account_id}" -class _RateLimiterRedisClient(Protocol): - def zadd(self, name: str | bytes, mapping: dict[str | bytes | int | float, float | int | str | bytes]) -> int: ... - - def zremrangebyscore(self, name: str | bytes, min: str | float, max: str | float) -> int: ... - - def zcard(self, name: str | bytes) -> int: ... - - def expire(self, name: str | bytes, time: int) -> bool: ... - - -def _default_rate_limit_member_factory() -> str: - current_time = int(time.time()) - return f"{current_time}:{secrets.token_urlsafe(nbytes=8)}" - - class RateLimiter: - def __init__( - self, - prefix: str, - max_attempts: int, - time_window: int, - member_factory: Callable[[], str] = _default_rate_limit_member_factory, - redis_client: _RateLimiterRedisClient = redis_client, - ): + def __init__(self, prefix: str, max_attempts: int, time_window: int): self.prefix = prefix self.max_attempts = max_attempts self.time_window = time_window - self._member_factory = member_factory - self._redis_client = redis_client def _get_key(self, email: str) -> str: return f"{self.prefix}:{email}" @@ -470,8 +419,8 @@ class RateLimiter: current_time = int(time.time()) window_start_time = current_time - self.time_window - self._redis_client.zremrangebyscore(key, "-inf", window_start_time) - attempts = self._redis_client.zcard(key) + redis_client.zremrangebyscore(key, "-inf", window_start_time) + attempts = redis_client.zcard(key) if attempts and int(attempts) >= self.max_attempts: return True @@ -479,8 +428,7 @@ class RateLimiter: def increment_rate_limit(self, email: str): key = self._get_key(email) - member = self._member_factory() current_time = int(time.time()) - self._redis_client.zadd(key, {member: current_time}) - self._redis_client.expire(key, self.time_window * 2) + redis_client.zadd(key, {current_time: current_time}) + redis_client.expire(key, self.time_window * 2) diff --git a/api/migrations/versions/2026_01_29_1415-e8c3b3c46151_add_human_input_related_db_models.py b/api/migrations/versions/2026_01_29_1415-e8c3b3c46151_add_human_input_related_db_models.py deleted file mode 100644 index a1546ef940..0000000000 --- a/api/migrations/versions/2026_01_29_1415-e8c3b3c46151_add_human_input_related_db_models.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Add human input related db models - -Revision ID: e8c3b3c46151 -Revises: 788d3099ae3a -Create Date: 2026-01-29 14:15:23.081903 - -""" - -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. 
-revision = "e8c3b3c46151" -down_revision = "788d3099ae3a" -branch_labels = None -depends_on = None - - -def upgrade(): - op.create_table( - "execution_extra_contents", - sa.Column("id", models.types.StringUUID(), nullable=False), - sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - - sa.Column("type", sa.String(length=30), nullable=False), - sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False), - sa.Column("message_id", models.types.StringUUID(), nullable=True), - sa.Column("form_id", models.types.StringUUID(), nullable=True), - sa.PrimaryKeyConstraint("id", name=op.f("execution_extra_contents_pkey")), - ) - with op.batch_alter_table("execution_extra_contents", schema=None) as batch_op: - batch_op.create_index(batch_op.f("execution_extra_contents_message_id_idx"), ["message_id"], unique=False) - batch_op.create_index( - batch_op.f("execution_extra_contents_workflow_run_id_idx"), ["workflow_run_id"], unique=False - ) - - op.create_table( - "human_input_form_deliveries", - sa.Column("id", models.types.StringUUID(), nullable=False), - sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - - sa.Column("form_id", models.types.StringUUID(), nullable=False), - sa.Column("delivery_method_type", sa.String(length=20), nullable=False), - sa.Column("delivery_config_id", models.types.StringUUID(), nullable=True), - sa.Column("channel_payload", sa.Text(), nullable=False), - sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_deliveries_pkey")), - ) - - op.create_table( - "human_input_form_recipients", - sa.Column("id", models.types.StringUUID(), nullable=False), - sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - - sa.Column("form_id", models.types.StringUUID(), nullable=False), - sa.Column("delivery_id", models.types.StringUUID(), nullable=False), - sa.Column("recipient_type", sa.String(length=20), nullable=False), - sa.Column("recipient_payload", sa.Text(), nullable=False), - sa.Column("access_token", sa.VARCHAR(length=32), nullable=False), - sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_recipients_pkey")), - ) - with op.batch_alter_table('human_input_form_recipients', schema=None) as batch_op: - batch_op.create_unique_constraint(batch_op.f('human_input_form_recipients_access_token_key'), ['access_token']) - - op.create_table( - "human_input_forms", - sa.Column("id", models.types.StringUUID(), nullable=False), - sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), - - sa.Column("tenant_id", models.types.StringUUID(), nullable=False), - sa.Column("app_id", models.types.StringUUID(), nullable=False), - sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True), - sa.Column("form_kind", sa.String(length=20), nullable=False), - sa.Column("node_id", sa.String(length=60), nullable=False), - sa.Column("form_definition", sa.Text(), nullable=False), - sa.Column("rendered_content", sa.Text(), nullable=False), - sa.Column("status", sa.String(length=20), 
nullable=False), - sa.Column("expiration_time", sa.DateTime(), nullable=False), - sa.Column("selected_action_id", sa.String(length=200), nullable=True), - sa.Column("submitted_data", sa.Text(), nullable=True), - sa.Column("submitted_at", sa.DateTime(), nullable=True), - sa.Column("submission_user_id", models.types.StringUUID(), nullable=True), - sa.Column("submission_end_user_id", models.types.StringUUID(), nullable=True), - sa.Column("completed_by_recipient_id", models.types.StringUUID(), nullable=True), - - sa.PrimaryKeyConstraint("id", name=op.f("human_input_forms_pkey")), - ) - - -def downgrade(): - op.drop_table("human_input_forms") - op.drop_table("human_input_form_recipients") - op.drop_table("human_input_form_deliveries") - op.drop_table("execution_extra_contents") diff --git a/api/models/__init__.py b/api/models/__init__.py index 1d5d604ba7..74b33130ef 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -34,8 +34,6 @@ from .enums import ( WorkflowRunTriggeredFrom, WorkflowTriggerStatus, ) -from .execution_extra_content import ExecutionExtraContent, HumanInputContent -from .human_input import HumanInputForm from .model import ( AccountTrialAppRecord, ApiRequest, @@ -157,12 +155,9 @@ __all__ = [ "DocumentSegment", "Embedding", "EndUser", - "ExecutionExtraContent", "ExporleBanner", "ExternalKnowledgeApis", "ExternalKnowledgeBindings", - "HumanInputContent", - "HumanInputForm", "IconType", "InstalledApp", "InvitationCode", diff --git a/api/models/base.py b/api/models/base.py index aa93d31199..c8a5e20f25 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -41,7 +41,7 @@ class DefaultFieldsMixin: ) updated_at: Mapped[datetime] = mapped_column( - DateTime, + __name_pos=DateTime, nullable=False, default=naive_utc_now, server_default=func.current_timestamp(), diff --git a/api/models/enums.py b/api/models/enums.py index 2bc61120ce..8cd3d4cf2a 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -36,7 +36,6 @@ class MessageStatus(StrEnum): """ NORMAL = "normal" - PAUSED = "paused" ERROR = "error" diff --git a/api/models/execution_extra_content.py b/api/models/execution_extra_content.py deleted file mode 100644 index d0bd34efec..0000000000 --- a/api/models/execution_extra_content.py +++ /dev/null @@ -1,78 +0,0 @@ -from enum import StrEnum, auto -from typing import TYPE_CHECKING - -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from .base import Base, DefaultFieldsMixin -from .types import EnumText, StringUUID - -if TYPE_CHECKING: - from .human_input import HumanInputForm - - -class ExecutionContentType(StrEnum): - HUMAN_INPUT = auto() - - -class ExecutionExtraContent(DefaultFieldsMixin, Base): - """ExecutionExtraContent stores extra contents produced during workflow / chatflow execution.""" - - # The `ExecutionExtraContent` uses single table inheritance to model different - # kinds of contents produced during message generation. - # - # See: https://docs.sqlalchemy.org/en/20/orm/inheritance.html#single-table-inheritance - - __tablename__ = "execution_extra_contents" - __mapper_args__ = { - "polymorphic_abstract": True, - "polymorphic_on": "type", - "with_polymorphic": "*", - } - # type records the type of the content. It serves as the `discriminator` for the - # single table inheritance. - type: Mapped[ExecutionContentType] = mapped_column( - EnumText(ExecutionContentType, length=30), - nullable=False, - ) - - # `workflow_run_id` records the workflow execution which generates this content, correspond to - # `WorkflowRun.id`. 
- workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) - - # `message_id` records the messages generated by the execution associated with this `ExecutionExtraContent`. - # It references to `Message.id`. - # - # For workflow execution, this field is `None`. - # - # For chatflow execution, `message_id`` is not None, and the following condition holds: - # - # The message referenced by `message_id` has `message.workflow_run_id == execution_extra_content.workflow_run_id` - # - message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, index=True) - - -class HumanInputContent(ExecutionExtraContent): - """HumanInputContent is a concrete class that represents human input content. - It should only be initialized with the `new` class method.""" - - __mapper_args__ = { - "polymorphic_identity": ExecutionContentType.HUMAN_INPUT, - } - - # A relation to HumanInputForm table. - # - # While the form_id column is nullable in database (due to the nature of single table inheritance), - # the form_id field should not be null for a given `HumanInputContent` instance. - form_id: Mapped[str] = mapped_column(StringUUID, nullable=True) - - @classmethod - def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent": - return cls(form_id=form_id, message_id=message_id) - - form: Mapped["HumanInputForm"] = relationship( - "HumanInputForm", - foreign_keys=[form_id], - uselist=False, - lazy="raise", - primaryjoin="foreign(HumanInputContent.form_id) == HumanInputForm.id", - ) diff --git a/api/models/human_input.py b/api/models/human_input.py deleted file mode 100644 index 5208461de1..0000000000 --- a/api/models/human_input.py +++ /dev/null @@ -1,237 +0,0 @@ -from datetime import datetime -from enum import StrEnum -from typing import Annotated, Literal, Self, final - -import sqlalchemy as sa -from pydantic import BaseModel, Field -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from core.workflow.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) -from libs.helper import generate_string - -from .base import Base, DefaultFieldsMixin -from .types import EnumText, StringUUID - -_token_length = 22 -# A 32-character string can store a base64-encoded value with 192 bits of entropy -# or a base62-encoded value with over 180 bits of entropy, providing sufficient -# uniqueness for most use cases. -_token_field_length = 32 -_email_field_length = 330 - - -def _generate_token() -> str: - return generate_string(_token_length) - - -class HumanInputForm(DefaultFieldsMixin, Base): - __tablename__ = "human_input_forms" - - tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - form_kind: Mapped[HumanInputFormKind] = mapped_column( - EnumText(HumanInputFormKind), - nullable=False, - default=HumanInputFormKind.RUNTIME, - ) - - # The human input node the current form corresponds to. 
-    node_id: Mapped[str] = mapped_column(sa.String(60), nullable=False)
-    form_definition: Mapped[str] = mapped_column(sa.Text, nullable=False)
-    rendered_content: Mapped[str] = mapped_column(sa.Text, nullable=False)
-    status: Mapped[HumanInputFormStatus] = mapped_column(
-        EnumText(HumanInputFormStatus),
-        nullable=False,
-        default=HumanInputFormStatus.WAITING,
-    )
-
-    expiration_time: Mapped[datetime] = mapped_column(
-        sa.DateTime,
-        nullable=False,
-    )
-
-    # Submission-related fields (nullable until a submission happens).
-    selected_action_id: Mapped[str | None] = mapped_column(sa.String(200), nullable=True)
-    submitted_data: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
-    submitted_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
-    submission_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
-    submission_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
-
-    completed_by_recipient_id: Mapped[str | None] = mapped_column(
-        StringUUID,
-        nullable=True,
-    )
-
-    deliveries: Mapped[list["HumanInputDelivery"]] = relationship(
-        "HumanInputDelivery",
-        primaryjoin="HumanInputForm.id == foreign(HumanInputDelivery.form_id)",
-        uselist=True,
-        back_populates="form",
-        lazy="raise",
-    )
-    completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship(
-        "HumanInputFormRecipient",
-        primaryjoin="HumanInputForm.completed_by_recipient_id == foreign(HumanInputFormRecipient.id)",
-        lazy="raise",
-        viewonly=True,
-    )
-
-
-class HumanInputDelivery(DefaultFieldsMixin, Base):
-    __tablename__ = "human_input_form_deliveries"
-
-    form_id: Mapped[str] = mapped_column(
-        StringUUID,
-        nullable=False,
-    )
-    delivery_method_type: Mapped[DeliveryMethodType] = mapped_column(
-        EnumText(DeliveryMethodType),
-        nullable=False,
-    )
-    delivery_config_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
-    channel_payload: Mapped[str] = mapped_column(sa.Text, nullable=False)
-
-    form: Mapped[HumanInputForm] = relationship(
-        "HumanInputForm",
-        uselist=False,
-        foreign_keys=[form_id],
-        primaryjoin="HumanInputDelivery.form_id == HumanInputForm.id",
-        back_populates="deliveries",
-        lazy="raise",
-    )
-
-    recipients: Mapped[list["HumanInputFormRecipient"]] = relationship(
-        "HumanInputFormRecipient",
-        primaryjoin="HumanInputDelivery.id == foreign(HumanInputFormRecipient.delivery_id)",
-        uselist=True,
-        back_populates="delivery",
-        # Require explicit preloading
-        lazy="raise",
-    )
-
-
-class RecipientType(StrEnum):
-    # EMAIL_MEMBER means the recipient is a workspace member notified via email.
-    EMAIL_MEMBER = "email_member"
-    EMAIL_EXTERNAL = "email_external"
-    # STANDALONE_WEB_APP is used by the standalone web app.
-    #
-    # It's not used while running workflows / chatflows containing HumanInput
-    # node inside console.
-    STANDALONE_WEB_APP = "standalone_web_app"
-    # CONSOLE is used while running workflows / chatflows containing HumanInput
-    # node inside console. (e.g. running installed apps or debugging workflows / chatflows)
-    CONSOLE = "console"
-    # BACKSTAGE is used for backstage input inside console.
-    BACKSTAGE = "backstage"
-
-
-@final
-class EmailMemberRecipientPayload(BaseModel):
-    TYPE: Literal[RecipientType.EMAIL_MEMBER] = RecipientType.EMAIL_MEMBER
-    user_id: str
-
-    # The `email` field here is only used for mail sending.
- email: str - - -@final -class EmailExternalRecipientPayload(BaseModel): - TYPE: Literal[RecipientType.EMAIL_EXTERNAL] = RecipientType.EMAIL_EXTERNAL - email: str - - -@final -class StandaloneWebAppRecipientPayload(BaseModel): - TYPE: Literal[RecipientType.STANDALONE_WEB_APP] = RecipientType.STANDALONE_WEB_APP - - -@final -class ConsoleRecipientPayload(BaseModel): - TYPE: Literal[RecipientType.CONSOLE] = RecipientType.CONSOLE - account_id: str | None = None - - -@final -class BackstageRecipientPayload(BaseModel): - TYPE: Literal[RecipientType.BACKSTAGE] = RecipientType.BACKSTAGE - account_id: str | None = None - - -@final -class ConsoleDeliveryPayload(BaseModel): - type: Literal["console"] = "console" - internal: bool = True - - -RecipientPayload = Annotated[ - EmailMemberRecipientPayload - | EmailExternalRecipientPayload - | StandaloneWebAppRecipientPayload - | ConsoleRecipientPayload - | BackstageRecipientPayload, - Field(discriminator="TYPE"), -] - - -class HumanInputFormRecipient(DefaultFieldsMixin, Base): - __tablename__ = "human_input_form_recipients" - - form_id: Mapped[str] = mapped_column( - StringUUID, - nullable=False, - ) - delivery_id: Mapped[str] = mapped_column( - StringUUID, - nullable=False, - ) - recipient_type: Mapped["RecipientType"] = mapped_column(EnumText(RecipientType), nullable=False) - recipient_payload: Mapped[str] = mapped_column(sa.Text, nullable=False) - - # Token primarily used for authenticated resume links (email, etc.). - access_token: Mapped[str | None] = mapped_column( - sa.VARCHAR(_token_field_length), - nullable=False, - default=_generate_token, - unique=True, - ) - - delivery: Mapped[HumanInputDelivery] = relationship( - "HumanInputDelivery", - uselist=False, - foreign_keys=[delivery_id], - back_populates="recipients", - primaryjoin="HumanInputFormRecipient.delivery_id == HumanInputDelivery.id", - # Require explicit preloading - lazy="raise", - ) - - form: Mapped[HumanInputForm] = relationship( - "HumanInputForm", - uselist=False, - foreign_keys=[form_id], - primaryjoin="HumanInputFormRecipient.form_id == HumanInputForm.id", - # Require explicit preloading - lazy="raise", - ) - - @classmethod - def new( - cls, - form_id: str, - delivery_id: str, - payload: RecipientPayload, - ) -> Self: - recipient_model = cls( - form_id=form_id, - delivery_id=delivery_id, - recipient_type=payload.TYPE, - recipient_payload=payload.model_dump_json(), - access_token=_generate_token(), - ) - return recipient_model diff --git a/api/models/model.py b/api/models/model.py index c12362f359..c1c6e04ce9 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import re import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from datetime import datetime from decimal import Decimal from enum import StrEnum, auto @@ -943,7 +943,6 @@ class Conversation(Base): WorkflowExecutionStatus.FAILED: 0, WorkflowExecutionStatus.STOPPED: 0, WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0, - WorkflowExecutionStatus.PAUSED: 0, } for message in messages: @@ -964,7 +963,6 @@ class Conversation(Base): "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], "failed": status_counts[WorkflowExecutionStatus.FAILED], "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED], - "paused": status_counts[WorkflowExecutionStatus.PAUSED], } @property @@ -1347,14 +1345,6 @@ class Message(Base): db.session.commit() return result - # TODO(QuantumGhost): dirty hacks, fix this later. 
- def set_extra_contents(self, contents: Sequence[dict[str, Any]]) -> None: - self._extra_contents = list(contents) - - @property - def extra_contents(self) -> list[dict[str, Any]]: - return getattr(self, "_extra_contents", []) - @property def workflow_run(self): if self.workflow_run_id: diff --git a/api/models/workflow.py b/api/models/workflow.py index 94e0881bd1..df83228c2a 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -20,7 +20,6 @@ from sqlalchemy import ( select, ) from sqlalchemy.orm import Mapped, declared_attr, mapped_column -from typing_extensions import deprecated from core.file.constants import maybe_file_object from core.file.models import File @@ -31,7 +30,7 @@ from core.workflow.constants import ( SYSTEM_VARIABLE_NODE_ID, ) from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from core.workflow.enums import NodeType, WorkflowExecutionStatus +from core.workflow.enums import NodeType from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -406,11 +405,6 @@ class Workflow(Base): # bug return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) @property - @deprecated( - "This property is not accurate for determining if a workflow is published as a tool." - "It only checks if there's a WorkflowToolProvider for the app, " - "not if this specific workflow version is the one being used by the tool." - ) def tool_published(self) -> bool: """ DEPRECATED: This property is not accurate for determining if a workflow is published as a tool. @@ -613,16 +607,13 @@ class WorkflowRun(Base): version: Mapped[str] = mapped_column(String(255)) graph: Mapped[str | None] = mapped_column(LongText) inputs: Mapped[str | None] = mapped_column(LongText) - status: Mapped[WorkflowExecutionStatus] = mapped_column( - EnumText(WorkflowExecutionStatus, length=255), - nullable=False, - ) + status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[str | None] = mapped_column(LongText, default="{}") error: Mapped[str | None] = mapped_column(LongText) elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) - created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) # account, end_user + created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) finished_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -638,13 +629,11 @@ class WorkflowRun(Base): ) @property - @deprecated("This method is retained for historical reasons; avoid using it if possible.") def created_by_account(self): created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property - @deprecated("This method is retained for historical reasons; avoid using it if possible.") def created_by_end_user(self): from .model import EndUser @@ -664,7 +653,6 @@ class WorkflowRun(Base): return json.loads(self.outputs) if 
self.outputs else {} @property - @deprecated("This method is retained for historical reasons; avoid using it if possible.") def message(self): from .model import Message @@ -673,7 +661,6 @@ class WorkflowRun(Base): ) @property - @deprecated("This method is retained for historical reasons; avoid using it if possible.") def workflow(self): return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() @@ -1874,12 +1861,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base): def to_entity(self) -> PauseReason: if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: - return HumanInputRequired( - form_id=self.form_id, - form_content="", - node_id=self.node_id, - node_title="", - ) + return HumanInputRequired(form_id=self.form_id, node_id=self.node_id) elif self.type_ == PauseReasonType.SCHEDULED_PAUSE: return SchedulingPause(message=self.message) else: diff --git a/api/pyproject.toml b/api/pyproject.toml index 16395573f4..482dd4c8ad 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "numpy~=1.26.4", "openpyxl~=3.1.5", "opik~=1.8.72", - "litellm==1.77.1", # Pinned to avoid madoka dependency issue + "litellm==1.77.1", # Pinned to avoid madoka dependency issue "opentelemetry-api==1.27.0", "opentelemetry-distro==0.48b0", "opentelemetry-exporter-otlp==1.27.0", @@ -230,23 +230,3 @@ vdb = [ "mo-vector~=0.1.13", "mysql-connector-python>=9.3.0", ] - -[tool.mypy] - -[[tool.mypy.overrides]] -# targeted ignores for current type-check errors -# TODO(QuantumGhost): suppress type errors in HITL related code. -# fix the type error later -module = [ - "configs.middleware.cache.redis_pubsub_config", - "extensions.ext_redis", - "tasks.workflow_execution_tasks", - "core.workflow.nodes.base.node", - "services.human_input_delivery_test_service", - "core.app.apps.advanced_chat.app_generator", - "controllers.console.human_input_form", - "controllers.console.app.workflow_run", - "repositories.sqlalchemy_api_workflow_node_execution_repository", - "extensions.logstore.repositories.logstore_api_workflow_run_repository", -] -ignore_errors = true diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 6446eb0d6e..5b3f635301 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -10,7 +10,6 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain m """ from collections.abc import Sequence -from dataclasses import dataclass from datetime import datetime from typing import Protocol @@ -20,27 +19,6 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload -@dataclass(frozen=True) -class WorkflowNodeExecutionSnapshot: - """ - Minimal snapshot of workflow node execution for stream recovery. - - Only includes fields required by snapshot events. - """ - - execution_id: str # Unique execution identifier (node_execution_id or row id). - node_id: str # Workflow graph node id. - node_type: str # Workflow graph node type (e.g. "human-input"). - title: str # Human-friendly node title. - index: int # Execution order index within the workflow run. - status: str # Execution status (running/succeeded/failed/paused). - elapsed_time: float # Execution elapsed time in seconds. - created_at: datetime # Execution created timestamp. 
- finished_at: datetime | None # Execution finished timestamp. - iteration_id: str | None = None # Iteration id from execution metadata, if any. - loop_id: str | None = None # Loop id from execution metadata, if any. - - class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol): """ Protocol for service-layer operations on WorkflowNodeExecutionModel. @@ -101,8 +79,6 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr Args: tenant_id: The tenant identifier app_id: The application identifier - workflow_id: The workflow identifier - triggered_from: The workflow trigger source workflow_run_id: The workflow run identifier Returns: @@ -110,27 +86,6 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr """ ... - def get_execution_snapshots_by_workflow_run( - self, - tenant_id: str, - app_id: str, - workflow_id: str, - triggered_from: str, - workflow_run_id: str, - ) -> Sequence[WorkflowNodeExecutionSnapshot]: - """ - Get minimal snapshots for node executions in a workflow run. - - Args: - tenant_id: The tenant identifier - app_id: The application identifier - workflow_run_id: The workflow run identifier - - Returns: - A sequence of WorkflowNodeExecutionSnapshot ordered by creation time - """ - ... - def get_execution_by_id( self, execution_id: str, diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 17e01a6e18..1d3954571f 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -432,13 +432,6 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): # while creating pause. ... - def get_workflow_pause(self, workflow_run_id: str) -> WorkflowPauseEntity | None: - """Retrieve the current pause for a workflow execution. - - If there is no current pause, this method would return `None`. - """ - ... - def resume_workflow_pause( self, workflow_run_id: str, @@ -634,19 +627,3 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): [{"date": "2024-01-01", "interactions": 2.5}, ...] """ ... - - def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None: - """ - Get a specific workflow run by its id and the associated tenant id. - - This function does not apply application isolation. It should only be used when - the application identifier is not available. - - Args: - tenant_id: Tenant identifier for multi-tenant isolation - run_id: Workflow run identifier - - Returns: - WorkflowRun object if found, None otherwise - """ - ... diff --git a/api/repositories/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py index a3c4039aaa..b970f39816 100644 --- a/api/repositories/entities/workflow_pause.py +++ b/api/repositories/entities/workflow_pause.py @@ -63,12 +63,6 @@ class WorkflowPauseEntity(ABC): """ pass - @property - @abstractmethod - def paused_at(self) -> datetime: - """`paused_at` returns the creation time of the pause.""" - pass - @abstractmethod def get_pause_reasons(self) -> Sequence[PauseReason]: """ @@ -76,5 +70,7 @@ class WorkflowPauseEntity(ABC): Returns a sequence of `PauseReason` objects describing the specific nodes and reasons for which the workflow execution was paused. + This information is related to, but distinct from, the `PauseReason` type + defined in `api/core/workflow/entities/pause_reason.py`. """ ... 
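
The `WorkflowPauseEntity` contract above ends with `get_pause_reasons()`, which yields one `PauseReason` per paused node; callers branch on the concrete reason type. Below is a minimal consumer sketch using simplified stand-in dataclasses, with field sets reduced to the `form_id`/`node_id` and `message` fields the simplified `to_entity()` in this diff still populates; `describe_pause` is a hypothetical helper, not project code:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HumanInputRequired:
    # Stand-in for core.workflow.entities.pause_reason.HumanInputRequired,
    # reduced to the fields the simplified to_entity() above still sets.
    form_id: str
    node_id: str


@dataclass(frozen=True)
class SchedulingPause:
    # Stand-in for core.workflow.entities.pause_reason.SchedulingPause.
    message: str


PauseReason = HumanInputRequired | SchedulingPause


def describe_pause(reasons: list[PauseReason]) -> list[str]:
    # Hypothetical consumer of WorkflowPauseEntity.get_pause_reasons():
    # branch on the concrete reason type, as repository callers do.
    descriptions: list[str] = []
    for reason in reasons:
        if isinstance(reason, HumanInputRequired):
            descriptions.append(f"waiting on form {reason.form_id} at node {reason.node_id}")
        else:
            descriptions.append(f"scheduled pause: {reason.message}")
    return descriptions
```
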
diff --git a/api/repositories/execution_extra_content_repository.py b/api/repositories/execution_extra_content_repository.py deleted file mode 100644 index 72b5443d2c..0000000000 --- a/api/repositories/execution_extra_content_repository.py +++ /dev/null @@ -1,13 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from typing import Protocol - -from core.entities.execution_extra_content import ExecutionExtraContentDomainModel - - -class ExecutionExtraContentRepository(Protocol): - def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: ... - - -__all__ = ["ExecutionExtraContentRepository"] diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 6c696b6478..b19cc73bd1 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -5,7 +5,6 @@ This module provides a concrete implementation of the service repository protoco using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. """ -import json from collections.abc import Sequence from datetime import datetime from typing import cast @@ -14,12 +13,11 @@ from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload -from repositories.api_workflow_node_execution_repository import ( - DifyAPIWorkflowNodeExecutionRepository, - WorkflowNodeExecutionSnapshot, +from models.workflow import ( + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, ) +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): @@ -81,7 +79,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut WorkflowNodeExecutionModel.app_id == app_id, WorkflowNodeExecutionModel.workflow_id == workflow_id, WorkflowNodeExecutionModel.node_id == node_id, - WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED, ) .order_by(desc(WorkflowNodeExecutionModel.created_at)) .limit(1) @@ -120,80 +117,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut with self._session_maker() as session: return session.execute(stmt).scalars().all() - def get_execution_snapshots_by_workflow_run( - self, - tenant_id: str, - app_id: str, - workflow_id: str, - triggered_from: str, - workflow_run_id: str, - ) -> Sequence[WorkflowNodeExecutionSnapshot]: - stmt = ( - select( - WorkflowNodeExecutionModel.id, - WorkflowNodeExecutionModel.node_execution_id, - WorkflowNodeExecutionModel.node_id, - WorkflowNodeExecutionModel.node_type, - WorkflowNodeExecutionModel.title, - WorkflowNodeExecutionModel.index, - WorkflowNodeExecutionModel.status, - WorkflowNodeExecutionModel.elapsed_time, - WorkflowNodeExecutionModel.created_at, - WorkflowNodeExecutionModel.finished_at, - WorkflowNodeExecutionModel.execution_metadata, - ) - .where( - WorkflowNodeExecutionModel.tenant_id == tenant_id, - WorkflowNodeExecutionModel.app_id == app_id, - WorkflowNodeExecutionModel.workflow_id == workflow_id, - WorkflowNodeExecutionModel.triggered_from == 
triggered_from, - WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, - ) - .order_by( - asc(WorkflowNodeExecutionModel.created_at), - asc(WorkflowNodeExecutionModel.index), - ) - ) - - with self._session_maker() as session: - rows = session.execute(stmt).all() - - return [self._row_to_snapshot(row) for row in rows] - - @staticmethod - def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot: - metadata: dict[str, object] = {} - execution_metadata = getattr(row, "execution_metadata", None) - if execution_metadata: - try: - metadata = json.loads(execution_metadata) - except json.JSONDecodeError: - metadata = {} - iteration_id = metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID.value) - loop_id = metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID.value) - execution_id = getattr(row, "node_execution_id", None) or row.id - elapsed_time = getattr(row, "elapsed_time", None) - created_at = row.created_at - finished_at = getattr(row, "finished_at", None) - if elapsed_time is None: - if finished_at is not None and created_at is not None: - elapsed_time = (finished_at - created_at).total_seconds() - else: - elapsed_time = 0.0 - return WorkflowNodeExecutionSnapshot( - execution_id=str(execution_id), - node_id=row.node_id, - node_type=row.node_type, - title=row.title, - index=row.index, - status=row.status, - elapsed_time=float(elapsed_time), - created_at=created_at, - finished_at=finished_at, - iteration_id=str(iteration_id) if iteration_id else None, - loop_id=str(loop_id) if loop_id else None, - ) - def get_execution_by_id( self, execution_id: str, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 00cb979e17..d5214be042 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -19,7 +19,6 @@ Implementation Notes: - Maintains data consistency with proper transaction handling """ -import json import logging import uuid from collections.abc import Callable, Sequence @@ -28,14 +27,12 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa -from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker -from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause from core.workflow.enums import WorkflowExecutionStatus, WorkflowType -from core.workflow.nodes.human_input.entities import FormDefinition from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date @@ -43,7 +40,6 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -61,67 +57,6 @@ class _WorkflowRunError(Exception): pass -def _select_recipient_token( - 
recipients: Sequence[HumanInputFormRecipient], - recipient_type: RecipientType, -) -> str | None: - for recipient in recipients: - if recipient.recipient_type == recipient_type and recipient.access_token: - return recipient.access_token - return None - - -def _build_human_input_required_reason( - reason_model: WorkflowPauseReason, - form_model: HumanInputForm | None, - recipients: Sequence[HumanInputFormRecipient], -) -> HumanInputRequired: - form_content = "" - inputs = [] - actions = [] - display_in_ui = False - resolved_default_values: dict[str, Any] = {} - node_title = "Human Input" - form_id = reason_model.form_id - node_id = reason_model.node_id - if form_model is not None: - form_id = form_model.id - node_id = form_model.node_id or node_id - try: - definition_payload = json.loads(form_model.form_definition) - if "expiration_time" not in definition_payload: - definition_payload["expiration_time"] = form_model.expiration_time - definition = FormDefinition.model_validate(definition_payload) - except ValidationError: - definition = None - - if definition is not None: - form_content = definition.form_content - inputs = list(definition.inputs) - actions = list(definition.user_actions) - display_in_ui = bool(definition.display_in_ui) - resolved_default_values = dict(definition.default_values) - node_title = definition.node_title or node_title - - form_token = ( - _select_recipient_token(recipients, RecipientType.BACKSTAGE) - or _select_recipient_token(recipients, RecipientType.CONSOLE) - or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP) - ) - - return HumanInputRequired( - form_id=form_id, - form_content=form_content, - inputs=inputs, - actions=actions, - display_in_ui=display_in_ui, - node_id=node_id, - node_title=node_title, - form_token=form_token, - resolved_default_values=resolved_default_values, - ) - - class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): """ SQLAlchemy implementation of APIWorkflowRunRepository. 
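
Two small patterns in the helpers deleted above are worth spelling out: `_select_recipient_token` resolved recipient tokens in a fixed precedence order (backstage, then console, then standalone web app), and `_build_human_input_required_reason` parsed the stored `form_definition` JSON defensively, backfilling `expiration_time` from the column before pydantic validation and tolerating invalid payloads. A condensed sketch of both, using illustrative stand-in models (`Recipient`, `FormDef`) rather than the real ones:

```python
import json

from pydantic import BaseModel


class Recipient(BaseModel):
    # Illustrative stand-in for models.human_input.HumanInputFormRecipient.
    recipient_type: str
    access_token: str | None = None


def select_token(recipients: list[Recipient], precedence: list[str]) -> str | None:
    # First non-empty token wins, scanning types in precedence order; the
    # deleted code used backstage -> console -> standalone_web_app.
    for recipient_type in precedence:
        for recipient in recipients:
            if recipient.recipient_type == recipient_type and recipient.access_token:
                return recipient.access_token
    return None


class FormDef(BaseModel):
    # Illustrative stand-in for FormDefinition; the real model also carries
    # inputs, user_actions, default_values, display_in_ui, node_title, etc.
    form_content: str = ""
    expiration_time: int


def parse_definition(raw: str, column_expiration: int) -> FormDef | None:
    # Older rows may lack expiration_time in the JSON payload, so it is
    # backfilled from the dedicated column before validation.
    try:
        payload = json.loads(raw)
        if not isinstance(payload, dict):
            return None
        payload.setdefault("expiration_time", column_expiration)
        return FormDef.model_validate(payload)
    except ValueError:  # pydantic.ValidationError subclasses ValueError
        return None  # degrade gracefully instead of failing the whole read
```
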
@@ -741,11 +676,9 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): raise ValueError(f"WorkflowRun not found: {workflow_run_id}") # Check if workflow is in RUNNING status - # TODO(QuantumGhost): It seems that the persistence of `WorkflowRun.status` - # happens before the execution of GraphLayer - if workflow_run.status not in {WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.PAUSED}: + if workflow_run.status != WorkflowExecutionStatus.RUNNING: raise _WorkflowRunError( - f"Only WorkflowRun with RUNNING or PAUSED status can be paused, " + f"Only WorkflowRun with RUNNING status can be paused, " f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}" ) # @@ -796,48 +729,13 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) - return _PrivateWorkflowPauseEntity( - pause_model=pause_model, - reason_models=pause_reason_models, - pause_reasons=pause_reasons, - ) + return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models) def _get_reasons_by_pause_id(self, session: Session, pause_id: str): reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id) pause_reason_models = session.scalars(reason_stmt).all() return pause_reason_models - def _hydrate_pause_reasons( - self, - session: Session, - pause_reason_models: Sequence[WorkflowPauseReason], - ) -> list[PauseReason]: - form_ids = [ - reason.form_id - for reason in pause_reason_models - if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id - ] - form_models: dict[str, HumanInputForm] = {} - recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {} - if form_ids: - form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) - for form in session.scalars(form_stmt).all(): - form_models[form.id] = form - - recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) - for recipient in session.scalars(recipient_stmt).all(): - recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient) - - pause_reasons: list[PauseReason] = [] - for reason in pause_reason_models: - if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: - form_model = form_models.get(reason.form_id) - recipients = recipient_models_by_form.get(reason.form_id, []) - pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients)) - else: - pause_reasons.append(reason.to_entity()) - return pause_reasons - def get_workflow_pause( self, workflow_run_id: str, @@ -869,12 +767,14 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): if pause_model is None: return None pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id) - pause_reasons = self._hydrate_pause_reasons(session, pause_reason_models) + + human_input_form: list[Any] = [] + # TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason return _PrivateWorkflowPauseEntity( pause_model=pause_model, reason_models=pause_reason_models, - pause_reasons=pause_reasons, + human_input_form=human_input_form, ) def resume_workflow_pause( @@ -928,10 +828,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}") pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id) - hydrated_pause_reasons = 
self._hydrate_pause_reasons(session, pause_reasons) # Mark as resumed pause_model.resumed_at = naive_utc_now() + workflow_run.pause_id = None # type: ignore workflow_run.status = WorkflowExecutionStatus.RUNNING session.add(pause_model) @@ -939,11 +839,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id) - return _PrivateWorkflowPauseEntity( - pause_model=pause_model, - reason_models=pause_reasons, - pause_reasons=hydrated_pause_reasons, - ) + return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons) def delete_workflow_pause( self, @@ -1269,15 +1165,6 @@ GROUP BY return cast(list[AverageInteractionStats], response_data) - def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None: - """Get a specific workflow run by its id and the associated tenant id.""" - with self._session_maker() as session: - stmt = select(WorkflowRun).where( - WorkflowRun.tenant_id == tenant_id, - WorkflowRun.id == run_id, - ) - return session.scalar(stmt) - class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): """ @@ -1292,12 +1179,10 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): *, pause_model: WorkflowPause, reason_models: Sequence[WorkflowPauseReason], - pause_reasons: Sequence[PauseReason] | None = None, human_input_form: Sequence = (), ) -> None: self._pause_model = pause_model self._reason_models = reason_models - self._pause_reasons = pause_reasons self._cached_state: bytes | None = None self._human_input_form = human_input_form @@ -1334,10 +1219,4 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): return self._pause_model.resumed_at def get_pause_reasons(self) -> Sequence[PauseReason]: - if self._pause_reasons is not None: - return list(self._pause_reasons) return [reason.to_entity() for reason in self._reason_models] - - @property - def paused_at(self) -> datetime: - return self._pause_model.created_at diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py deleted file mode 100644 index 5a2c0ea46f..0000000000 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ /dev/null @@ -1,200 +0,0 @@ -from __future__ import annotations - -import json -import logging -import re -from collections import defaultdict -from collections.abc import Sequence -from typing import Any - -from sqlalchemy import select -from sqlalchemy.orm import Session, selectinload, sessionmaker - -from core.entities.execution_extra_content import ( - ExecutionExtraContentDomainModel, - HumanInputFormDefinition, - HumanInputFormSubmissionData, -) -from core.entities.execution_extra_content import ( - HumanInputContent as HumanInputContentDomainModel, -) -from core.workflow.nodes.human_input.entities import FormDefinition -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from models.execution_extra_content import ( - ExecutionExtraContent as ExecutionExtraContentModel, -) -from models.execution_extra_content import ( - HumanInputContent as HumanInputContentModel, -) -from models.human_input import HumanInputFormRecipient, RecipientType -from repositories.execution_extra_content_repository import ExecutionExtraContentRepository - -logger = logging.getLogger(__name__) - -_OUTPUT_VARIABLE_PATTERN = 
re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
-
-
-def _extract_output_field_names(form_content: str) -> list[str]:
-    if not form_content:
-        return []
-    return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(form_content)]
-
-
-class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository):
-    def __init__(self, session_maker: sessionmaker[Session]):
-        self._session_maker = session_maker
-
-    def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
-        if not message_ids:
-            return []
-
-        grouped_contents: dict[str, list[ExecutionExtraContentDomainModel]] = {
-            message_id: [] for message_id in message_ids
-        }
-
-        stmt = (
-            select(ExecutionExtraContentModel)
-            .where(ExecutionExtraContentModel.message_id.in_(message_ids))
-            .options(selectinload(HumanInputContentModel.form))
-            .order_by(ExecutionExtraContentModel.created_at.asc())
-        )
-
-        with self._session_maker() as session:
-            results = session.scalars(stmt).all()
-
-            form_ids = {
-                content.form_id
-                for content in results
-                if isinstance(content, HumanInputContentModel) and content.form_id is not None
-            }
-            recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = defaultdict(list)
-            if form_ids:
-                recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
-                recipients = session.scalars(recipient_stmt).all()
-                for recipient in recipients:
-                    recipients_by_form_id[recipient.form_id].append(recipient)
-            else:
-                recipients_by_form_id = {}
-
-        for content in results:
-            message_id = content.message_id
-            if not message_id or message_id not in grouped_contents:
-                continue
-
-            domain_model = self._map_model_to_domain(content, recipients_by_form_id)
-            if domain_model is None:
-                continue
-
-            grouped_contents[message_id].append(domain_model)
-
-        return [grouped_contents[message_id] for message_id in message_ids]
-
-    def _map_model_to_domain(
-        self,
-        model: ExecutionExtraContentModel,
-        recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
-    ) -> ExecutionExtraContentDomainModel | None:
-        if isinstance(model, HumanInputContentModel):
-            return self._map_human_input_content(model, recipients_by_form_id)
-
-        logger.debug("Unsupported execution extra content type encountered: %s", model.type)
-        return None
-
-    def _map_human_input_content(
-        self,
-        model: HumanInputContentModel,
-        recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
-    ) -> HumanInputContentDomainModel | None:
-        form = model.form
-        if form is None:
-            logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id)
-            return None
-
-        try:
-            definition_payload = json.loads(form.form_definition)
-            if "expiration_time" not in definition_payload:
-                definition_payload["expiration_time"] = form.expiration_time
-            form_definition = FormDefinition.model_validate(definition_payload)
-        except ValueError:
-            logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id)
-            return None
-        node_title = form_definition.node_title or form.node_id
-        display_in_ui = bool(form_definition.display_in_ui)
-
-        submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED
-        if not submitted:
-            form_token = self._resolve_form_token(recipients_by_form_id.get(form.id, []))
-            return HumanInputContentDomainModel(
-                workflow_run_id=model.workflow_run_id,
-                submitted=False,
-                form_definition=HumanInputFormDefinition(
-                    form_id=form.id,
-                    node_id=form.node_id,
-
node_title=node_title, - form_content=form.rendered_content, - inputs=form_definition.inputs, - actions=form_definition.user_actions, - display_in_ui=display_in_ui, - form_token=form_token, - resolved_default_values=form_definition.default_values, - expiration_time=int(form.expiration_time.timestamp()), - ), - ) - - selected_action_id = form.selected_action_id - if not selected_action_id: - logger.warning("HumanInputContent(id=%s) form has no selected action", model.id) - return None - - action_text = next( - (action.title for action in form_definition.user_actions if action.id == selected_action_id), - selected_action_id, - ) - - submitted_data: dict[str, Any] = {} - if form.submitted_data: - try: - submitted_data = json.loads(form.submitted_data) - except ValueError: - logger.warning("Failed to load submitted data for HumanInputContent(id=%s)", model.id) - return None - - rendered_content = HumanInputNode.render_form_content_with_outputs( - form.rendered_content, - submitted_data, - _extract_output_field_names(form_definition.form_content), - ) - - return HumanInputContentDomainModel( - workflow_run_id=model.workflow_run_id, - submitted=True, - form_submission_data=HumanInputFormSubmissionData( - node_id=form.node_id, - node_title=node_title, - rendered_content=rendered_content, - action_id=selected_action_id, - action_text=action_text, - ), - ) - - @staticmethod - def _resolve_form_token(recipients: Sequence[HumanInputFormRecipient]) -> str | None: - console_recipient = next( - (recipient for recipient in recipients if recipient.recipient_type == RecipientType.CONSOLE), - None, - ) - if console_recipient and console_recipient.access_token: - return console_recipient.access_token - - web_app_recipient = next( - (recipient for recipient in recipients if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP), - None, - ) - if web_app_recipient and web_app_recipient.access_token: - return web_app_recipient.access_token - - return None - - -__all__ = ["SQLAlchemyExecutionExtraContentRepository"] diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index 1f6740b066..f3dc4cd60b 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -92,16 +92,6 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): return list(self.session.scalars(query).all()) - def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None: - """Get the trigger log associated with a workflow run.""" - query = ( - select(WorkflowTriggerLog) - .where(WorkflowTriggerLog.workflow_run_id == workflow_run_id) - .order_by(WorkflowTriggerLog.created_at.desc()) - .limit(1) - ) - return self.session.scalar(query) - def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: """ Delete trigger logs associated with the given workflow run ids. diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py index 7f9e6b7b68..b0009e398d 100644 --- a/api/repositories/workflow_trigger_log_repository.py +++ b/api/repositories/workflow_trigger_log_repository.py @@ -110,18 +110,6 @@ class WorkflowTriggerLogRepository(Protocol): """ ... - def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None: - """ - Retrieve a trigger log associated with a specific workflow run. 
- - Args: - workflow_run_id: Identifier of the workflow run - - Returns: - The matching WorkflowTriggerLog if present, None otherwise - """ - ... - def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: """ Delete trigger logs for workflow run IDs. diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 9400362605..0f42c99246 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -44,7 +44,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.6.0" +CURRENT_DSL_VERSION = "0.5.0" class ImportMode(StrEnum): diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 0c27c403f8..ce85f2e914 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,9 +1,7 @@ from __future__ import annotations -import logging -import threading import uuid -from collections.abc import Callable, Generator, Mapping +from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Union from configs import dify_config @@ -11,63 +9,22 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.completion.app_generator import CompletionAppGenerator -from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting import RateLimit -from core.app.features.rate_limiting.rate_limit import rate_limit_context -from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig -from core.db import session_factory from enums.quota_type import QuotaType, unlimited from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser -from models.workflow import Workflow, WorkflowRun +from models.workflow import Workflow from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService -from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task - -logger = logging.getLogger(__name__) - -SSE_TASK_START_FALLBACK_MS = 200 if TYPE_CHECKING: from controllers.console.app.workflow import LoopNodeRunPayload class AppGenerateService: - @staticmethod - def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]: - started = False - lock = threading.Lock() - - def _try_start() -> bool: - nonlocal started - with lock: - if started: - return True - try: - start_task() - except Exception: - logger.exception("Failed to enqueue streaming task") - return False - started = True - return True - - # XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber. - # The Celery task may publish the first event before the API side actually subscribes, - # causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe, - # but also use a short fallback timer so the task still runs if the client never consumes. 
-        timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)
-        timer.daemon = True
-        timer.start()
-
-        def _on_subscribe() -> None:
-            if _try_start():
-                timer.cancel()
-
-        return _on_subscribe
-
     @classmethod
     @trace_span(AppGenerateHandler)
     def generate(
@@ -131,29 +88,15 @@
         elif app_model.mode == AppMode.ADVANCED_CHAT:
             workflow_id = args.get("workflow_id")
             workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
-            with rate_limit_context(rate_limit, request_id):
-                payload = AppExecutionParams.new(
-                    app_model=app_model,
-                    workflow=workflow,
-                    user=user,
-                    args=args,
-                    invoke_from=invoke_from,
-                    streaming=streaming,
-                    call_depth=0,
-                )
-                payload_json = payload.model_dump_json()
-
-                def on_subscribe():
-                    workflow_based_app_execution_task.delay(payload_json)
-
-                on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
-
-            generator = AdvancedChatAppGenerator()
             return rate_limit.generate(
-                generator.convert_to_event_stream(
-                    generator.retrieve_events(
-                        AppMode.ADVANCED_CHAT,
-                        payload.workflow_run_id,
-                        on_subscribe=on_subscribe,
+                AdvancedChatAppGenerator.convert_to_event_stream(
+                    AdvancedChatAppGenerator().generate(
+                        app_model=app_model,
+                        workflow=workflow,
+                        user=user,
+                        args=args,
+                        invoke_from=invoke_from,
+                        streaming=streaming,
                     ),
                 ),
                 request_id=request_id,
@@ -161,40 +104,6 @@
         elif app_model.mode == AppMode.WORKFLOW:
             workflow_id = args.get("workflow_id")
             workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
-            if streaming:
-                with rate_limit_context(rate_limit, request_id):
-                    payload = AppExecutionParams.new(
-                        app_model=app_model,
-                        workflow=workflow,
-                        user=user,
-                        args=args,
-                        invoke_from=invoke_from,
-                        streaming=True,
-                        call_depth=0,
-                        root_node_id=root_node_id,
-                        workflow_run_id=str(uuid.uuid4()),
-                    )
-                    payload_json = payload.model_dump_json()
-
-                    def on_subscribe():
-                        workflow_based_app_execution_task.delay(payload_json)
-
-                    on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
-                return rate_limit.generate(
-                    WorkflowAppGenerator.convert_to_event_stream(
-                        MessageBasedAppGenerator.retrieve_events(
-                            AppMode.WORKFLOW,
-                            payload.workflow_run_id,
-                            on_subscribe=on_subscribe,
-                        ),
-                    ),
-                    request_id,
-                )
-
-            pause_config = PauseStateLayerConfig(
-                session_factory=session_factory.get_session_maker(),
-                state_owner_user_id=workflow.created_by,
-            )
             return rate_limit.generate(
                 WorkflowAppGenerator.convert_to_event_stream(
                     WorkflowAppGenerator().generate(
@@ -203,10 +112,9 @@
                         user=user,
                        args=args,
                         invoke_from=invoke_from,
-                        streaming=False,
+                        streaming=streaming,
                         root_node_id=root_node_id,
                         call_depth=0,
-                        pause_state_config=pause_config,
                     ),
                 ),
                 request_id,
@@ -340,19 +248,3 @@
             raise ValueError("Workflow not published")
 
         return workflow
-
-    @classmethod
-    def get_response_generator(
-        cls,
-        app_model: App,
-        workflow_run: WorkflowRun,
-    ):
-        if workflow_run.status.is_ended():
-            # TODO(QuantumGhost): handle the ended scenario.
- pass - - generator = AdvancedChatAppGenerator() - - return generator.convert_to_event_stream( - generator.retrieve_events(AppMode(app_model.mode), workflow_run.id), - ) diff --git a/api/services/audio_service.py b/api/services/audio_service.py index a95361cebd..41ee9c88aa 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -136,7 +136,7 @@ class AudioService: message = db.session.query(Message).where(Message.id == message_id).first() if message is None: return None - if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}: + if message.answer == "" and message.status == MessageStatus.NORMAL: return None else: diff --git a/api/services/feature_service.py b/api/services/feature_service.py index fda3a15144..d94ae49d91 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -138,8 +138,6 @@ class FeatureModel(BaseModel): is_allow_transfer_workspace: bool = True trigger_event: Quota = Quota(usage=0, limit=3000, reset_date=0) api_rate_limit: Quota = Quota(usage=0, limit=5000, reset_date=0) - # Controls whether email delivery is allowed for HumanInput nodes. - human_input_email_delivery_enabled: bool = False # pydantic configs model_config = ConfigDict(protected_namespaces=()) knowledge_pipeline: KnowledgePipeline = KnowledgePipeline() @@ -193,11 +191,6 @@ class FeatureService: features.knowledge_pipeline.publish_enabled = True cls._fulfill_params_from_workspace_info(features, tenant_id) - features.human_input_email_delivery_enabled = cls._resolve_human_input_email_delivery_enabled( - features=features, - tenant_id=tenant_id, - ) - return features @classmethod @@ -210,17 +203,6 @@ class FeatureService: knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", CloudPlan.SANDBOX) return knowledge_rate_limit - @classmethod - def _resolve_human_input_email_delivery_enabled(cls, *, features: FeatureModel, tenant_id: str | None) -> bool: - if dify_config.ENTERPRISE_ENABLED or not dify_config.BILLING_ENABLED: - return True - if not tenant_id: - return False - return features.billing.enabled and features.billing.subscription.plan in ( - CloudPlan.PROFESSIONAL, - CloudPlan.TEAM, - ) - @classmethod def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel: system_features = SystemFeatureModel() diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py deleted file mode 100644 index ff37ff098f..0000000000 --- a/api/services/human_input_delivery_test_service.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from enum import StrEnum -from typing import Protocol - -from sqlalchemy import Engine, select -from sqlalchemy.orm import sessionmaker - -from configs import dify_config -from core.workflow.nodes.human_input.entities import ( - DeliveryChannelConfig, - EmailDeliveryConfig, - EmailDeliveryMethod, - ExternalRecipient, - MemberRecipient, -) -from core.workflow.runtime import VariablePool -from extensions.ext_database import db -from extensions.ext_mail import mail -from libs.email_template_renderer import render_email_template -from models import Account, TenantAccountJoin -from services.feature_service import FeatureService - - -class DeliveryTestStatus(StrEnum): - OK = "ok" - FAILED = "failed" - - -@dataclass(frozen=True) -class DeliveryTestEmailRecipient: - email: str - form_token: str - - -@dataclass(frozen=True) -class DeliveryTestContext: - 
tenant_id: str - app_id: str - node_id: str - node_title: str | None - rendered_content: str - template_vars: dict[str, str] = field(default_factory=dict) - recipients: list[DeliveryTestEmailRecipient] = field(default_factory=list) - variable_pool: VariablePool | None = None - - -@dataclass(frozen=True) -class DeliveryTestResult: - status: DeliveryTestStatus - delivered_to: list[str] = field(default_factory=list) - warnings: list[str] = field(default_factory=list) - - -class DeliveryTestError(Exception): - pass - - -class DeliveryTestUnsupportedError(DeliveryTestError): - pass - - -def _build_form_link(token: str | None) -> str | None: - if not token: - return None - base_url = dify_config.APP_WEB_URL - if not base_url: - return None - return f"{base_url.rstrip('/')}/form/{token}" - - -class DeliveryTestHandler(Protocol): - def supports(self, method: DeliveryChannelConfig) -> bool: ... - - def send_test( - self, - *, - context: DeliveryTestContext, - method: DeliveryChannelConfig, - ) -> DeliveryTestResult: ... - - -class DeliveryTestRegistry: - def __init__(self, handlers: list[DeliveryTestHandler] | None = None) -> None: - self._handlers = list(handlers or []) - - def register(self, handler: DeliveryTestHandler) -> None: - self._handlers.append(handler) - - def dispatch( - self, - *, - context: DeliveryTestContext, - method: DeliveryChannelConfig, - ) -> DeliveryTestResult: - for handler in self._handlers: - if handler.supports(method): - return handler.send_test(context=context, method=method) - raise DeliveryTestUnsupportedError("Delivery method does not support test send.") - - @classmethod - def default(cls) -> DeliveryTestRegistry: - return cls([EmailDeliveryTestHandler()]) - - -class HumanInputDeliveryTestService: - def __init__(self, registry: DeliveryTestRegistry | None = None) -> None: - self._registry = registry or DeliveryTestRegistry.default() - - def send_test( - self, - *, - context: DeliveryTestContext, - method: DeliveryChannelConfig, - ) -> DeliveryTestResult: - return self._registry.dispatch(context=context, method=method) - - -class EmailDeliveryTestHandler: - def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None: - if session_factory is None: - session_factory = sessionmaker(bind=db.engine) - elif isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - - def supports(self, method: DeliveryChannelConfig) -> bool: - return isinstance(method, EmailDeliveryMethod) - - def send_test( - self, - *, - context: DeliveryTestContext, - method: DeliveryChannelConfig, - ) -> DeliveryTestResult: - if not isinstance(method, EmailDeliveryMethod): - raise DeliveryTestUnsupportedError("Delivery method does not support test send.") - features = FeatureService.get_features(context.tenant_id) - if not features.human_input_email_delivery_enabled: - raise DeliveryTestError("Email delivery is not available for current plan.") - if not mail.is_inited(): - raise DeliveryTestError("Mail client is not initialized.") - - recipients = self._resolve_recipients( - tenant_id=context.tenant_id, - method=method, - ) - if not recipients: - raise DeliveryTestError("No recipients configured for delivery method.") - - delivered: list[str] = [] - for recipient_email in recipients: - substitutions = self._build_substitutions( - context=context, - recipient_email=recipient_email, - ) - subject = render_email_template(method.config.subject, substitutions) - templated_body = 
EmailDeliveryConfig.render_body_template( - body=method.config.body, - url=substitutions.get("form_link"), - variable_pool=context.variable_pool, - ) - body = render_email_template(templated_body, substitutions) - - mail.send( - to=recipient_email, - subject=subject, - html=body, - ) - delivered.append(recipient_email) - - return DeliveryTestResult(status=DeliveryTestStatus.OK, delivered_to=delivered) - - def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]: - recipients = method.config.recipients - emails: list[str] = [] - member_user_ids: list[str] = [] - for recipient in recipients.items: - if isinstance(recipient, MemberRecipient): - member_user_ids.append(recipient.user_id) - elif isinstance(recipient, ExternalRecipient): - if recipient.email: - emails.append(recipient.email) - - if recipients.whole_workspace: - member_user_ids = [] - member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None) - emails.extend(member_emails.values()) - elif member_user_ids: - member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids) - for user_id in member_user_ids: - email = member_emails.get(user_id) - if email: - emails.append(email) - - return list(dict.fromkeys([email for email in emails if email])) - - def _query_workspace_member_emails( - self, - *, - tenant_id: str, - user_ids: list[str] | None, - ) -> dict[str, str]: - if user_ids is None: - unique_ids = None - else: - unique_ids = {user_id for user_id in user_ids if user_id} - if not unique_ids: - return {} - - stmt = ( - select(Account.id, Account.email) - .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) - .where(TenantAccountJoin.tenant_id == tenant_id) - ) - if unique_ids is not None: - stmt = stmt.where(Account.id.in_(unique_ids)) - - with self._session_factory() as session: - rows = session.execute(stmt).all() - return dict(rows) - - @staticmethod - def _build_substitutions( - *, - context: DeliveryTestContext, - recipient_email: str, - ) -> dict[str, str]: - raw_values: dict[str, str | None] = { - "form_id": "", - "node_title": context.node_title, - "workflow_run_id": "", - "form_token": "", - "form_link": "", - "form_content": context.rendered_content, - "recipient_email": recipient_email, - } - substitutions = {key: value or "" for key, value in raw_values.items()} - if context.template_vars: - substitutions.update({key: value for key, value in context.template_vars.items() if value is not None}) - token = next( - (recipient.form_token for recipient in context.recipients if recipient.email == recipient_email), - None, - ) - if token: - substitutions["form_token"] = token - substitutions["form_link"] = _build_form_link(token) or "" - return substitutions diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py deleted file mode 100644 index 76b6e6e0e6..0000000000 --- a/api/services/human_input_service.py +++ /dev/null @@ -1,250 +0,0 @@ -import logging -from collections.abc import Mapping -from datetime import datetime, timedelta -from typing import Any - -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, sessionmaker - -from configs import dify_config -from core.repositories.human_input_repository import ( - HumanInputFormRecord, - HumanInputFormSubmissionRepository, -) -from core.workflow.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from 
core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from libs.datetime_utils import ensure_naive_utc, naive_utc_now -from libs.exception import BaseHTTPException -from models.human_input import RecipientType -from models.model import App, AppMode -from repositories.factory import DifyAPIRepositoryFactory -from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE, resume_app_execution - - -class Form: - def __init__(self, record: HumanInputFormRecord): - self._record = record - - def get_definition(self) -> FormDefinition: - return self._record.definition - - @property - def submitted(self) -> bool: - return self._record.submitted - - @property - def id(self) -> str: - return self._record.form_id - - @property - def workflow_run_id(self) -> str | None: - """Workflow run id for runtime forms; None for delivery tests.""" - return self._record.workflow_run_id - - @property - def tenant_id(self) -> str: - return self._record.tenant_id - - @property - def app_id(self) -> str: - return self._record.app_id - - @property - def recipient_id(self) -> str | None: - return self._record.recipient_id - - @property - def recipient_type(self) -> RecipientType | None: - return self._record.recipient_type - - @property - def status(self) -> HumanInputFormStatus: - return self._record.status - - @property - def form_kind(self) -> HumanInputFormKind: - return self._record.form_kind - - @property - def created_at(self) -> "datetime": - return self._record.created_at - - @property - def expiration_time(self) -> "datetime": - return self._record.expiration_time - - -class HumanInputError(Exception): - pass - - -class FormSubmittedError(HumanInputError, BaseHTTPException): - error_code = "human_input_form_submitted" - description = "This form has already been submitted by another user, form_id={form_id}" - code = 412 - - def __init__(self, form_id: str): - template = self.description or "This form has already been submitted by another user, form_id={form_id}" - description = template.format(form_id=form_id) - super().__init__(description=description) - - -class FormNotFoundError(HumanInputError, BaseHTTPException): - error_code = "human_input_form_not_found" - code = 404 - - -class InvalidFormDataError(HumanInputError, BaseHTTPException): - error_code = "invalid_form_data" - code = 400 - - def __init__(self, description: str): - super().__init__(description=description) - - -class WebAppDeliveryNotEnabledError(HumanInputError): - pass - - -class FormExpiredError(HumanInputError, BaseHTTPException): - error_code = "human_input_form_expired" - code = 412 - - def __init__(self, form_id: str): - super().__init__(description=f"This form has expired, form_id={form_id}") - - -logger = logging.getLogger(__name__) - - -class HumanInputService: - def __init__( - self, - session_factory: sessionmaker[Session] | Engine, - form_repository: HumanInputFormSubmissionRepository | None = None, - ): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory) - - def get_form_by_token(self, form_token: str) -> Form | None: - record = self._form_repository.get_by_token(form_token) - if record is None: - return None - return Form(record) - - def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form | None: - form = 
self.get_form_by_token(form_token) - if form is None or form.recipient_type != recipient_type: - return None - self._ensure_not_submitted(form) - return form - - def get_form_definition_by_token_for_console(self, form_token: str) -> Form | None: - form = self.get_form_by_token(form_token) - if form is None or form.recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}: - return None - self._ensure_not_submitted(form) - return form - - def submit_form_by_token( - self, - recipient_type: RecipientType, - form_token: str, - selected_action_id: str, - form_data: Mapping[str, Any], - submission_end_user_id: str | None = None, - submission_user_id: str | None = None, - ): - form = self.get_form_by_token(form_token) - if form is None or form.recipient_type != recipient_type: - raise WebAppDeliveryNotEnabledError() - - self.ensure_form_active(form) - self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data) - - result = self._form_repository.mark_submitted( - form_id=form.id, - recipient_id=form.recipient_id, - selected_action_id=selected_action_id, - form_data=form_data, - submission_user_id=submission_user_id, - submission_end_user_id=submission_end_user_id, - ) - - if result.form_kind != HumanInputFormKind.RUNTIME: - return - if result.workflow_run_id is None: - return - self.enqueue_resume(result.workflow_run_id) - - def ensure_form_active(self, form: Form) -> None: - if form.submitted: - raise FormSubmittedError(form.id) - if form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}: - raise FormExpiredError(form.id) - now = naive_utc_now() - if ensure_naive_utc(form.expiration_time) <= now: - raise FormExpiredError(form.id) - if self._is_globally_expired(form, now=now): - raise FormExpiredError(form.id) - - def _ensure_not_submitted(self, form: Form) -> None: - if form.submitted: - raise FormSubmittedError(form.id) - - def _validate_submission(self, form: Form, selected_action_id: str, form_data: Mapping[str, Any]) -> None: - definition = form.get_definition() - try: - validate_human_input_submission( - inputs=definition.inputs, - user_actions=definition.user_actions, - selected_action_id=selected_action_id, - form_data=form_data, - ) - except HumanInputSubmissionValidationError as exc: - raise InvalidFormDataError(str(exc)) from exc - - def enqueue_resume(self, workflow_run_id: str) -> None: - workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory) - workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(workflow_run_id) - - if workflow_run is None: - raise AssertionError(f"WorkflowRun not found, id={workflow_run_id}") - with self._session_factory(expire_on_commit=False) as session: - app_query = select(App).where(App.id == workflow_run.app_id) - app = session.execute(app_query).scalar_one_or_none() - if app is None: - logger.error( - "App not found for WorkflowRun, workflow_run_id=%s, app_id=%s", workflow_run_id, workflow_run.app_id - ) - return - - if app.mode in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}: - payload = {"workflow_run_id": workflow_run_id} - try: - resume_app_execution.apply_async( - kwargs={"payload": payload}, - queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, - ) - except Exception: # pragma: no cover - logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id) - return - - logger.warning("App mode %s does not support resume for workflow run %s", app.mode, workflow_run_id) - - def _is_globally_expired(self, form: 
Form, *, now: datetime | None = None) -> bool: - global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS - if global_timeout_seconds <= 0: - return False - if form.workflow_run_id is None: - return False - current = now or naive_utc_now() - created_at = ensure_naive_utc(form.created_at) - global_deadline = created_at + timedelta(seconds=global_timeout_seconds) - return global_deadline <= current diff --git a/api/services/message_service.py b/api/services/message_service.py index ce699e79d4..a53ca8b22d 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,9 +1,6 @@ import json -from collections.abc import Sequence from typing import Union -from sqlalchemy.orm import sessionmaker - from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -17,10 +14,6 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback -from repositories.execution_extra_content_repository import ExecutionExtraContentRepository -from repositories.sqlalchemy_execution_extra_content_repository import ( - SQLAlchemyExecutionExtraContentRepository, -) from services.conversation_service import ConversationService from services.errors.message import ( FirstMessageNotExistsError, @@ -31,23 +24,6 @@ from services.errors.message import ( from services.workflow_service import WorkflowService -def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository: - session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) - return SQLAlchemyExecutionExtraContentRepository(session_maker=session_maker) - - -def attach_message_extra_contents(messages: Sequence[Message]) -> None: - if not messages: - return - - repository = _create_execution_extra_content_repository() - extra_contents_lists = repository.get_by_message_ids([message.id for message in messages]) - - for index, message in enumerate(messages): - contents = extra_contents_lists[index] if index < len(extra_contents_lists) else [] - message.set_extra_contents([content.model_dump(mode="json", exclude_none=True) for content in contents]) - - class MessageService: @classmethod def pagination_by_first_id( @@ -109,8 +85,6 @@ class MessageService: if order == "asc": history_messages = list(reversed(history_messages)) - attach_message_extra_contents(history_messages) - return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 0ae40199ab..ab5d5480df 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -67,8 +67,6 @@ class WorkflowToolManageService: if workflow is None: raise ValueError(f"Workflow not found for app {workflow_app_id}") - WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) - workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -160,8 +158,6 @@ class WorkflowToolManageService: if workflow is None: raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") - WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) - 
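For reference, the global-timeout rule removed from `HumanInputService` above reduces to a pure function of the form's creation time. A minimal sketch, assuming naive-UTC datetimes as in the deleted helpers; the function and parameter names are illustrative, not part of the codebase:

```python
from datetime import datetime, timedelta

def is_globally_expired(created_at: datetime, now: datetime, timeout_seconds: int) -> bool:
    # A non-positive timeout disables the global deadline entirely.
    if timeout_seconds <= 0:
        return False
    # Otherwise a run may not stay paused past created_at + timeout.
    return created_at + timedelta(seconds=timeout_seconds) <= now
```

Note that `ensure_form_active` checked this deadline independently of the per-form `expiration_time`; whichever fired first raised `FormExpiredError`.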
workflow_tool_provider.name = name workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) diff --git a/api/services/workflow/entities.py b/api/services/workflow/entities.py index 2af0d1fd90..70ec8d6e2a 100644 --- a/api/services/workflow/entities.py +++ b/api/services/workflow/entities.py @@ -98,12 +98,6 @@ class WorkflowTaskData(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) -class WorkflowResumeTaskData(BaseModel): - """Payload for workflow resumption tasks.""" - - workflow_run_id: str - - class AsyncTriggerExecutionResult(BaseModel): """Result from async trigger-based workflow execution""" diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py deleted file mode 100644 index dd4651f130..0000000000 --- a/api/services/workflow_event_snapshot_service.py +++ /dev/null @@ -1,460 +0,0 @@ -from __future__ import annotations - -import json -import logging -import queue -import threading -import time -from collections.abc import Generator, Mapping, Sequence -from dataclasses import dataclass -from typing import Any - -from sqlalchemy import desc, select -from sqlalchemy.orm import Session, sessionmaker - -from core.app.apps.message_generator import MessageGenerator -from core.app.entities.task_entities import ( - MessageReplaceStreamResponse, - NodeFinishStreamResponse, - NodeStartStreamResponse, - StreamEvent, - WorkflowPauseStreamResponse, - WorkflowStartStreamResponse, -) -from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.entities import WorkflowStartReason -from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from core.workflow.runtime import GraphRuntimeState -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter -from models.model import AppMode, Message -from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun -from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot -from repositories.entities.workflow_pause import WorkflowPauseEntity -from repositories.factory import DifyAPIRepositoryFactory - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class MessageContext: - conversation_id: str - message_id: str - created_at: int - answer: str | None = None - - -@dataclass -class BufferState: - queue: queue.Queue[Mapping[str, Any]] - stop_event: threading.Event - done_event: threading.Event - task_id_ready: threading.Event - task_id_hint: str | None = None - - -def build_workflow_event_stream( - *, - app_mode: AppMode, - workflow_run: WorkflowRun, - tenant_id: str, - app_id: str, - session_maker: sessionmaker[Session], - idle_timeout: float = 300, - ping_interval: float = 10.0, -) -> Generator[Mapping[str, Any] | str, None, None]: - topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id) - workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) - node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) - message_context = ( - _get_message_context(session_maker, workflow_run.id) if app_mode == AppMode.ADVANCED_CHAT else None - ) - - pause_entity: WorkflowPauseEntity | None = None - if workflow_run.status == WorkflowExecutionStatus.PAUSED: - try: - pause_entity = workflow_run_repo.get_workflow_pause(workflow_run.id) - except Exception: - logger.exception("Failed to load workflow pause for run %s", workflow_run.id) 
- pause_entity = None - - resumption_context = _load_resumption_context(pause_entity) - node_snapshots = node_execution_repo.get_execution_snapshots_by_workflow_run( - tenant_id=tenant_id, - app_id=app_id, - workflow_id=workflow_run.workflow_id, - # NOTE(QuantumGhost): for event resumption, we only care about - # the execution records from `WORKFLOW_RUN`. - # - # Ideally, filtering by `workflow_run_id` would be enough. However, - # due to the index on the `WorkflowNodeExecution` table, we have to - # add a filter condition on `triggered_from` to - # ensure that we can utilize the index. - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - workflow_run_id=workflow_run.id, - ) - - def _generate() -> Generator[Mapping[str, Any] | str, None, None]: - # send a PING event immediately to prevent the connection from staying in a pending state for a long time. - # - # This simplifies the debugging process, as the DevTools in Chrome does not - # provide a complete curl command for pending connections. - yield StreamEvent.PING.value - - last_msg_time = time.time() - last_ping_time = last_msg_time - - with topic.subscribe() as sub: - buffer_state = _start_buffering(sub) - try: - task_id = _resolve_task_id(resumption_context, buffer_state, workflow_run.id) - - snapshot_events = _build_snapshot_events( - workflow_run=workflow_run, - node_snapshots=node_snapshots, - task_id=task_id, - message_context=message_context, - pause_entity=pause_entity, - resumption_context=resumption_context, - ) - - for event in snapshot_events: - last_msg_time = time.time() - last_ping_time = last_msg_time - yield event - if _is_terminal_event(event, include_paused=True): - return - - while True: - if buffer_state.done_event.is_set() and buffer_state.queue.empty(): - return - - try: - event = buffer_state.queue.get(timeout=0.1) - except queue.Empty: - current_time = time.time() - if current_time - last_msg_time > idle_timeout: - logger.debug( - "No workflow events received for %s seconds, keeping stream open", - idle_timeout, - ) - last_msg_time = current_time - if current_time - last_ping_time >= ping_interval: - yield StreamEvent.PING.value - last_ping_time = current_time - continue - - last_msg_time = time.time() - last_ping_time = last_msg_time - yield event - if _is_terminal_event(event, include_paused=True): - return - finally: - buffer_state.stop_event.set() - - return _generate() - - -def _get_message_context(session_maker: sessionmaker[Session], workflow_run_id: str) -> MessageContext | None: - with session_maker() as session: - stmt = select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(desc(Message.created_at)) - message = session.scalar(stmt) - if message is None: - return None - created_at = int(message.created_at.timestamp()) if message.created_at else 0 - return MessageContext( - conversation_id=message.conversation_id, - message_id=message.id, - created_at=created_at, - answer=message.answer, - ) - - -def _load_resumption_context(pause_entity: WorkflowPauseEntity | None) -> WorkflowResumptionContext | None: - if pause_entity is None: - return None - try: - raw_state = pause_entity.get_state().decode() - return WorkflowResumptionContext.loads(raw_state) - except Exception: - logger.exception("Failed to load resumption context") - return None - - -def _resolve_task_id( - resumption_context: WorkflowResumptionContext | None, - buffer_state: BufferState | None, - workflow_run_id: str, - wait_timeout: float = 0.2, -) -> str: - if resumption_context is not None: - generate_entity = 
resumption_context.get_generate_entity() - if generate_entity.task_id: - return generate_entity.task_id - if buffer_state is None: - return workflow_run_id - if buffer_state.task_id_hint is None: - buffer_state.task_id_ready.wait(timeout=wait_timeout) - if buffer_state.task_id_hint: - return buffer_state.task_id_hint - return workflow_run_id - - -def _build_snapshot_events( - *, - workflow_run: WorkflowRun, - node_snapshots: Sequence[WorkflowNodeExecutionSnapshot], - task_id: str, - message_context: MessageContext | None, - pause_entity: WorkflowPauseEntity | None, - resumption_context: WorkflowResumptionContext | None, -) -> list[Mapping[str, Any]]: - events: list[Mapping[str, Any]] = [] - - workflow_started = _build_workflow_started_event( - workflow_run=workflow_run, - task_id=task_id, - ) - _apply_message_context(workflow_started, message_context) - events.append(workflow_started) - - if message_context is not None and message_context.answer is not None: - message_replace = _build_message_replace_event(task_id=task_id, answer=message_context.answer) - _apply_message_context(message_replace, message_context) - events.append(message_replace) - - for snapshot in node_snapshots: - node_started = _build_node_started_event( - workflow_run_id=workflow_run.id, - task_id=task_id, - snapshot=snapshot, - ) - _apply_message_context(node_started, message_context) - events.append(node_started) - - if snapshot.status != WorkflowNodeExecutionStatus.RUNNING.value: - node_finished = _build_node_finished_event( - workflow_run_id=workflow_run.id, - task_id=task_id, - snapshot=snapshot, - ) - _apply_message_context(node_finished, message_context) - events.append(node_finished) - - if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None: - pause_event = _build_pause_event( - workflow_run=workflow_run, - workflow_run_id=workflow_run.id, - task_id=task_id, - pause_entity=pause_entity, - resumption_context=resumption_context, - ) - if pause_event is not None: - _apply_message_context(pause_event, message_context) - events.append(pause_event) - - return events - - -def _build_workflow_started_event( - *, - workflow_run: WorkflowRun, - task_id: str, -) -> dict[str, Any]: - response = WorkflowStartStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=WorkflowStartStreamResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - inputs=workflow_run.inputs_dict or {}, - created_at=int(workflow_run.created_at.timestamp()), - reason=WorkflowStartReason.INITIAL, - ), - ) - payload = response.model_dump(mode="json") - payload["event"] = response.event.value - return payload - - -def _build_message_replace_event(*, task_id: str, answer: str) -> dict[str, Any]: - response = MessageReplaceStreamResponse( - task_id=task_id, - answer=answer, - reason="", - ) - payload = response.model_dump(mode="json") - payload["event"] = response.event.value - return payload - - -def _build_node_started_event( - *, - workflow_run_id: str, - task_id: str, - snapshot: WorkflowNodeExecutionSnapshot, -) -> dict[str, Any]: - created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0 - response = NodeStartStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run_id, - data=NodeStartStreamResponse.Data( - id=snapshot.execution_id, - node_id=snapshot.node_id, - node_type=snapshot.node_type, - title=snapshot.title, - index=snapshot.index, - predecessor_node_id=None, - inputs=None, - created_at=created_at, - extras={}, - 
iteration_id=snapshot.iteration_id, - loop_id=snapshot.loop_id, - ), - ) - return response.to_ignore_detail_dict() - - -def _build_node_finished_event( - *, - workflow_run_id: str, - task_id: str, - snapshot: WorkflowNodeExecutionSnapshot, -) -> dict[str, Any]: - created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0 - finished_at = int(snapshot.finished_at.timestamp()) if snapshot.finished_at else created_at - response = NodeFinishStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run_id, - data=NodeFinishStreamResponse.Data( - id=snapshot.execution_id, - node_id=snapshot.node_id, - node_type=snapshot.node_type, - title=snapshot.title, - index=snapshot.index, - predecessor_node_id=None, - inputs=None, - process_data=None, - outputs=None, - status=snapshot.status, - error=None, - elapsed_time=snapshot.elapsed_time, - execution_metadata=None, - created_at=created_at, - finished_at=finished_at, - files=[], - iteration_id=snapshot.iteration_id, - loop_id=snapshot.loop_id, - ), - ) - return response.to_ignore_detail_dict() - - -def _build_pause_event( - *, - workflow_run: WorkflowRun, - workflow_run_id: str, - task_id: str, - pause_entity: WorkflowPauseEntity, - resumption_context: WorkflowResumptionContext | None, -) -> dict[str, Any] | None: - paused_nodes: list[str] = [] - outputs: dict[str, Any] = {} - if resumption_context is not None: - state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) - paused_nodes = state.get_paused_nodes() - outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {})) - - reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()] - response = WorkflowPauseStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run_id, - data=WorkflowPauseStreamResponse.Data( - workflow_run_id=workflow_run_id, - paused_nodes=paused_nodes, - outputs=outputs, - reasons=reasons, - status=workflow_run.status.value, - created_at=int(workflow_run.created_at.timestamp()), - elapsed_time=float(workflow_run.elapsed_time or 0.0), - total_tokens=int(workflow_run.total_tokens or 0), - total_steps=int(workflow_run.total_steps or 0), - ), - ) - payload = response.model_dump(mode="json") - payload["event"] = response.event.value - return payload - - -def _apply_message_context(payload: dict[str, Any], message_context: MessageContext | None) -> None: - if message_context is None: - return - payload["conversation_id"] = message_context.conversation_id - payload["message_id"] = message_context.message_id - payload["created_at"] = message_context.created_at - - -def _start_buffering(subscription) -> BufferState: - buffer_state = BufferState( - queue=queue.Queue(maxsize=2048), - stop_event=threading.Event(), - done_event=threading.Event(), - task_id_ready=threading.Event(), - ) - - def _worker() -> None: - dropped_count = 0 - try: - while not buffer_state.stop_event.is_set(): - msg = subscription.receive(timeout=0.1) - if msg is None: - continue - event = _parse_event_message(msg) - if event is None: - continue - task_id = event.get("task_id") - if task_id and buffer_state.task_id_hint is None: - buffer_state.task_id_hint = str(task_id) - buffer_state.task_id_ready.set() - try: - buffer_state.queue.put_nowait(event) - except queue.Full: - dropped_count += 1 - try: - buffer_state.queue.get_nowait() - except queue.Empty: - pass - try: - buffer_state.queue.put_nowait(event) - except queue.Full: - continue - logger.warning("Dropped buffered workflow event, 
total_dropped=%s", dropped_count) - except Exception: - logger.exception("Failed while buffering workflow events") - finally: - buffer_state.done_event.set() - - thread = threading.Thread(target=_worker, name=f"workflow-event-buffer-{id(subscription)}", daemon=True) - thread.start() - return buffer_state - - -def _parse_event_message(message: bytes) -> Mapping[str, Any] | None: - try: - event = json.loads(message) - except json.JSONDecodeError: - logger.warning("Failed to decode workflow event payload") - return None - if not isinstance(event, dict): - return None - return event - - -def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool: - if not isinstance(event, Mapping): - return False - event_type = event.get("event") - if event_type == StreamEvent.WORKFLOW_FINISHED.value: - return True - if include_paused: - return event_type == StreamEvent.WORKFLOW_PAUSED.value - return False diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 4e1e515de5..6404136994 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,5 +1,4 @@ import json -import logging import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence @@ -12,34 +11,21 @@ from configs import dify_config from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.variables import VariableBase from core.variables.variables import Variable -from core.workflow.entities import GraphInitParams, WorkflowNodeExecution -from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent from core.workflow.node_events import NodeRunResult from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node -from core.workflow.nodes.human_input.entities import ( - DeliveryChannelConfig, - HumanInputNodeData, - apply_debug_email_recipient, - validate_human_input_submission, -) -from core.workflow.nodes.human_input.enums import HumanInputFormKind -from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.repositories.human_input_form_repository import FormCreateParams -from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated @@ -48,8 +34,6 @@ from extensions.ext_storage import storage from factories.file_factory import 
build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from models import Account -from models.enums import UserFrom -from models.human_input import HumanInputFormRecipient, RecipientType from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType @@ -60,13 +44,6 @@ from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededEr from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError -from .human_input_delivery_test_service import ( - DeliveryTestContext, - DeliveryTestEmailRecipient, - DeliveryTestError, - DeliveryTestUnsupportedError, - HumanInputDeliveryTestService, -) from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService @@ -767,344 +744,6 @@ class WorkflowService: return workflow_node_execution - def get_human_input_form_preview( - self, - *, - app_model: App, - account: Account, - node_id: str, - inputs: Mapping[str, Any] | None = None, - ) -> Mapping[str, Any]: - """ - Build a human input form preview for a draft workflow. - - Args: - app_model: Target application model. - account: Current account. - node_id: Human input node ID. - inputs: Values used to fill missing upstream variables referenced in form_content. - """ - draft_workflow = self.get_draft_workflow(app_model=app_model) - if not draft_workflow: - raise ValueError("Workflow not initialized") - - node_config = draft_workflow.get_node_config_by_id(node_id) - node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: - raise ValueError("Node type must be human-input.") - - # inputs: values used to fill missing upstream variables referenced in form_content. - variable_pool = self._build_human_input_variable_pool( - app_model=app_model, - workflow=draft_workflow, - node_config=node_config, - manual_inputs=inputs or {}, - ) - node = self._build_human_input_node( - workflow=draft_workflow, - account=account, - node_config=node_config, - variable_pool=variable_pool, - ) - - rendered_content = node.render_form_content_before_submission() - resolved_default_values = node.resolve_default_values() - node_data = node.node_data - human_input_required = HumanInputRequired( - form_id=node_id, - form_content=rendered_content, - inputs=node_data.inputs, - actions=node_data.user_actions, - node_id=node_id, - node_title=node.title, - resolved_default_values=resolved_default_values, - form_token=None, - ) - return human_input_required.model_dump(mode="json") - - def submit_human_input_form_preview( - self, - *, - app_model: App, - account: Account, - node_id: str, - form_inputs: Mapping[str, Any], - inputs: Mapping[str, Any] | None = None, - action: str, - ) -> Mapping[str, Any]: - """ - Submit a human input form preview for a draft workflow. - - Args: - app_model: Target application model. - account: Current account. - node_id: Human input node ID. - form_inputs: Values the user provides for the form's own fields. - inputs: Values used to fill missing upstream variables referenced in form_content. - action: Selected action ID. 
- """ - draft_workflow = self.get_draft_workflow(app_model=app_model) - if not draft_workflow: - raise ValueError("Workflow not initialized") - - node_config = draft_workflow.get_node_config_by_id(node_id) - node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: - raise ValueError("Node type must be human-input.") - - # inputs: values used to fill missing upstream variables referenced in form_content. - # form_inputs: values the user provides for the form's own fields. - variable_pool = self._build_human_input_variable_pool( - app_model=app_model, - workflow=draft_workflow, - node_config=node_config, - manual_inputs=inputs or {}, - ) - node = self._build_human_input_node( - workflow=draft_workflow, - account=account, - node_config=node_config, - variable_pool=variable_pool, - ) - node_data = node.node_data - - validate_human_input_submission( - inputs=node_data.inputs, - user_actions=node_data.user_actions, - selected_action_id=action, - form_data=form_inputs, - ) - - rendered_content = node.render_form_content_before_submission() - outputs: dict[str, Any] = dict(form_inputs) - outputs["__action_id"] = action - outputs["__rendered_content"] = node.render_form_content_with_outputs( - rendered_content, outputs, node_data.outputs_field_names() - ) - - enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) - enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None - with Session(bind=db.engine) as session, session.begin(): - draft_var_saver = DraftVariableSaver( - session=session, - app_id=app_model.id, - node_id=node_id, - node_type=NodeType.HUMAN_INPUT, - node_execution_id=str(uuid.uuid4()), - user=account, - enclosing_node_id=enclosing_node_id, - ) - draft_var_saver.save(outputs=outputs, process_data={}) - session.commit() - - return outputs - - def test_human_input_delivery( - self, - *, - app_model: App, - account: Account, - node_id: str, - delivery_method_id: str, - inputs: Mapping[str, Any] | None = None, - ) -> None: - draft_workflow = self.get_draft_workflow(app_model=app_model) - if not draft_workflow: - raise ValueError("Workflow not initialized") - - node_config = draft_workflow.get_node_config_by_id(node_id) - node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: - raise ValueError("Node type must be human-input.") - - node_data = HumanInputNodeData.model_validate(node_config.get("data", {})) - delivery_method = self._resolve_human_input_delivery_method( - node_data=node_data, - delivery_method_id=delivery_method_id, - ) - if delivery_method is None: - raise ValueError("Delivery method not found.") - delivery_method = apply_debug_email_recipient( - delivery_method, - enabled=True, - user_id=account.id or "", - ) - - variable_pool = self._build_human_input_variable_pool( - app_model=app_model, - workflow=draft_workflow, - node_config=node_config, - manual_inputs=inputs or {}, - ) - node = self._build_human_input_node( - workflow=draft_workflow, - account=account, - node_config=node_config, - variable_pool=variable_pool, - ) - rendered_content = node.render_form_content_before_submission() - resolved_default_values = node.resolve_default_values() - form_id, recipients = self._create_human_input_delivery_test_form( - app_model=app_model, - node_id=node_id, - node_data=node_data, - delivery_method=delivery_method, - rendered_content=rendered_content, - resolved_default_values=resolved_default_values, - ) - 
test_service = HumanInputDeliveryTestService() - context = DeliveryTestContext( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - node_id=node_id, - node_title=node_data.title, - rendered_content=rendered_content, - template_vars={"form_id": form_id}, - recipients=recipients, - variable_pool=variable_pool, - ) - try: - test_service.send_test(context=context, method=delivery_method) - except DeliveryTestUnsupportedError as exc: - raise ValueError("Delivery method does not support test send.") from exc - except DeliveryTestError as exc: - raise ValueError(str(exc)) from exc - - @staticmethod - def _resolve_human_input_delivery_method( - *, - node_data: HumanInputNodeData, - delivery_method_id: str, - ) -> DeliveryChannelConfig | None: - for method in node_data.delivery_methods: - if str(method.id) == delivery_method_id: - return method - return None - - def _create_human_input_delivery_test_form( - self, - *, - app_model: App, - node_id: str, - node_data: HumanInputNodeData, - delivery_method: DeliveryChannelConfig, - rendered_content: str, - resolved_default_values: Mapping[str, Any], - ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(session_factory=db.engine, tenant_id=app_model.tenant_id) - params = FormCreateParams( - app_id=app_model.id, - workflow_execution_id=None, - node_id=node_id, - form_config=node_data, - rendered_content=rendered_content, - delivery_methods=[delivery_method], - display_in_ui=False, - resolved_default_values=resolved_default_values, - form_kind=HumanInputFormKind.DELIVERY_TEST, - ) - form_entity = repo.create_form(params) - return form_entity.id, self._load_email_recipients(form_entity.id) - - @staticmethod - def _load_email_recipients(form_id: str) -> list[DeliveryTestEmailRecipient]: - logger = logging.getLogger(__name__) - - with Session(bind=db.engine) as session: - recipients = session.scalars( - select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_id) - ).all() - recipients_data: list[DeliveryTestEmailRecipient] = [] - for recipient in recipients: - if recipient.recipient_type not in {RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL}: - continue - if not recipient.access_token: - continue - try: - payload = json.loads(recipient.recipient_payload) - except Exception: - logger.exception("Failed to parse human input recipient payload for delivery test.") - continue - email = payload.get("email") - if email: - recipients_data.append(DeliveryTestEmailRecipient(email=email, form_token=recipient.access_token)) - return recipients_data - - def _build_human_input_node( - self, - *, - workflow: Workflow, - account: Account, - node_config: Mapping[str, Any], - variable_pool: VariablePool, - ) -> HumanInputNode: - graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - graph_config=workflow.graph_dict, - user_id=account.id, - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.DEBUGGER.value, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=time.perf_counter(), - ) - node = HumanInputNode( - id=node_config.get("id", str(uuid.uuid4())), - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - return node - - def _build_human_input_variable_pool( - self, - *, - app_model: App, - workflow: Workflow, - node_config: Mapping[str, Any], - manual_inputs: Mapping[str, Any], - ) -> VariablePool: - with 
Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): - draft_var_srv = WorkflowDraftVariableService(session) - draft_var_srv.prefill_conversation_variable_default_values(workflow) - - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=[], - ) - - variable_loader = DraftVarLoader( - engine=db.engine, - app_id=app_model.id, - tenant_id=app_model.tenant_id, - ) - variable_mapping = HumanInputNode.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, - config=node_config, - ) - normalized_user_inputs: dict[str, Any] = dict(manual_inputs) - - load_into_variable_pool( - variable_loader=variable_loader, - variable_pool=variable_pool, - variable_mapping=variable_mapping, - user_inputs=normalized_user_inputs, - ) - WorkflowEntry.mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=normalized_user_inputs, - variable_pool=variable_pool, - tenant_id=app_model.tenant_id, - ) - - return variable_pool - def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: @@ -1306,13 +945,6 @@ class WorkflowService: if any(nt.is_trigger_node for nt in node_types): raise ValueError("Start node and trigger nodes cannot coexist in the same workflow") - for node in node_configs: - node_data = node.get("data", {}) - node_type = node_data.get("type") - - if node_type == NodeType.HUMAN_INPUT: - self._validate_human_input_node_data(node_data) - def validate_features_structure(self, app_model: App, features: dict): if app_model.mode == AppMode.ADVANCED_CHAT: return AdvancedChatAppConfigManager.config_validate( @@ -1325,23 +957,6 @@ class WorkflowService: else: raise ValueError(f"Invalid app mode: {app_model.mode}") - def _validate_human_input_node_data(self, node_data: dict) -> None: - """ - Validate HumanInput node data format. 
- - Args: - node_data: The node data dictionary - - Raises: - ValueError: If the node data format is invalid - """ - from core.workflow.nodes.human_input.entities import HumanInputNodeData - - try: - HumanInputNodeData.model_validate(node_data) - except Exception as e: - raise ValueError(f"Invalid HumanInput node data: {str(e)}") - - def update_workflow( - self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict - ) -> Workflow | None: diff --git a/api/tasks/app_generate/__init__.py b/api/tasks/app_generate/__init__.py deleted file mode 100644 index 4aa02ef39f..0000000000 --- a/api/tasks/app_generate/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .workflow_execute_task import AppExecutionParams, resume_app_execution, workflow_based_app_execution_task - -__all__ = ["AppExecutionParams", "resume_app_execution", "workflow_based_app_execution_task"] diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py deleted file mode 100644 index e58d334f41..0000000000 --- a/api/tasks/app_generate/workflow_execute_task.py +++ /dev/null @@ -1,491 +0,0 @@ -import contextlib -import logging -import uuid -from collections.abc import Generator, Mapping -from enum import StrEnum -from typing import Annotated, Any, TypeAlias, Union - -from celery import shared_task -from flask import current_app, json -from pydantic import BaseModel, Discriminator, Field, Tag -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, sessionmaker - -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator -from core.app.apps.message_based_app_generator import MessageBasedAppGenerator -from core.app.apps.workflow.app_generator import WorkflowAppGenerator -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, - InvokeFrom, - WorkflowAppGenerateEntity, -) -from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext -from core.repositories import DifyCoreRepositoryFactory -from core.workflow.runtime import GraphRuntimeState -from extensions.ext_database import db -from libs.flask_utils import set_login_user -from models.account import Account -from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom -from models.model import App, AppMode, Conversation, EndUser, Message -from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun -from repositories.factory import DifyAPIRepositoryFactory - -logger = logging.getLogger(__name__) - -WORKFLOW_BASED_APP_EXECUTION_QUEUE = "workflow_based_app_execution" - - -class _UserType(StrEnum): - ACCOUNT = "account" - END_USER = "end_user" - - -class _Account(BaseModel): - TYPE: _UserType = _UserType.ACCOUNT - - user_id: str - - -class _EndUser(BaseModel): - TYPE: _UserType = _UserType.END_USER - end_user_id: str - - -def _get_user_type_discriminator(value: Any): - if isinstance(value, (_Account, _EndUser)): - return value.TYPE - elif isinstance(value, dict): - user_type_str = value.get("TYPE") - if user_type_str is None: - return None - try: - user_type = _UserType(user_type_str) - except ValueError: - return None - return user_type - else: - # return None if the discriminator value isn't found - return None - - -User: TypeAlias = Annotated[ - (Annotated[_Account, Tag(_UserType.ACCOUNT)] | Annotated[_EndUser, Tag(_UserType.END_USER)]), - Discriminator(_get_user_type_discriminator), -] - - -class AppExecutionParams(BaseModel): - app_id: str - workflow_id: str - tenant_id: 
str - app_mode: AppMode = AppMode.ADVANCED_CHAT - user: User - args: Mapping[str, Any] - - invoke_from: InvokeFrom - streaming: bool = True - call_depth: int = 0 - root_node_id: str | None = None - workflow_run_id: str = Field(default_factory=lambda: str(uuid.uuid4())) - - @classmethod - def new( - cls, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: Mapping[str, Any], - invoke_from: InvokeFrom, - streaming: bool = True, - call_depth: int = 0, - root_node_id: str | None = None, - workflow_run_id: str | None = None, - ): - user_params: _Account | _EndUser - if isinstance(user, Account): - user_params = _Account(user_id=user.id) - elif isinstance(user, EndUser): - user_params = _EndUser(end_user_id=user.id) - else: - raise AssertionError("this statement should be unreachable.") - return cls( - app_id=app_model.id, - workflow_id=workflow.id, - tenant_id=app_model.tenant_id, - app_mode=AppMode.value_of(app_model.mode), - user=user_params, - args=args, - invoke_from=invoke_from, - streaming=streaming, - call_depth=call_depth, - root_node_id=root_node_id, - workflow_run_id=workflow_run_id or str(uuid.uuid4()), - ) - - -class _AppRunner: - def __init__(self, session_factory: sessionmaker | Engine, exec_params: AppExecutionParams): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - self._exec_params = exec_params - - @contextlib.contextmanager - def _session(self): - with self._session_factory(expire_on_commit=False) as session, session.begin(): - yield session - - @contextlib.contextmanager - def _setup_flask_context(self, user: Account | EndUser): - flask_app = current_app._get_current_object() # type: ignore - with flask_app.app_context(): - set_login_user(user) - yield - - def run(self): - exec_params = self._exec_params - with self._session() as session: - workflow = session.get(Workflow, exec_params.workflow_id) - if workflow is None: - logger.warning("Workflow %s not found for execution", exec_params.workflow_id) - return None - app = session.get(App, workflow.app_id) - if app is None: - logger.warning("App %s not found for workflow %s", workflow.app_id, exec_params.workflow_id) - return None - - pause_config = PauseStateLayerConfig( - session_factory=self._session_factory, - state_owner_user_id=workflow.created_by, - ) - - user = self._resolve_user() - - with self._setup_flask_context(user): - response = self._run_app( - app=app, - workflow=workflow, - user=user, - pause_state_config=pause_config, - ) - if not exec_params.streaming: - return response - - assert isinstance(response, Generator) - _publish_streaming_response(response, exec_params.workflow_run_id, exec_params.app_mode) - - def _run_app( - self, - *, - app: App, - workflow: Workflow, - user: Account | EndUser, - pause_state_config: PauseStateLayerConfig, - ): - exec_params = self._exec_params - if exec_params.app_mode == AppMode.ADVANCED_CHAT: - return AdvancedChatAppGenerator().generate( - app_model=app, - workflow=workflow, - user=user, - args=exec_params.args, - invoke_from=exec_params.invoke_from, - streaming=exec_params.streaming, - workflow_run_id=exec_params.workflow_run_id, - pause_state_config=pause_state_config, - ) - if exec_params.app_mode == AppMode.WORKFLOW: - return WorkflowAppGenerator().generate( - app_model=app, - workflow=workflow, - user=user, - args=exec_params.args, - invoke_from=exec_params.invoke_from, - streaming=exec_params.streaming, - call_depth=exec_params.call_depth, - 
root_node_id=exec_params.root_node_id, - workflow_run_id=exec_params.workflow_run_id, - pause_state_config=pause_state_config, - ) - - logger.error("Unsupported app mode for execution: %s", exec_params.app_mode) - return None - - def _resolve_user(self) -> Account | EndUser: - user_params = self._exec_params.user - - if isinstance(user_params, _EndUser): - with self._session() as session: - return session.get(EndUser, user_params.end_user_id) - elif not isinstance(user_params, _Account): - raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}") - - with self._session() as session: - user: Account = session.get(Account, user_params.user_id) - user.set_tenant_id(self._exec_params.tenant_id) - - return user - - -def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None: - role = CreatorUserRole(workflow_run.created_by_role) - if role == CreatorUserRole.ACCOUNT: - user = session.get(Account, workflow_run.created_by) - if user: - user.set_tenant_id(workflow_run.tenant_id) - return user - - return session.get(EndUser, workflow_run.created_by) - - -def _publish_streaming_response( - response_stream: Generator[str | Mapping[str, Any], None, None], workflow_run_id: str, app_mode: AppMode -) -> None: - topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id) - for event in response_stream: - try: - payload = json.dumps(event) - except TypeError: - logger.exception("error while encoding event") - continue - - topic.publish(payload.encode()) - - -@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE) -def workflow_based_app_execution_task( - payload: str, -) -> Generator[Mapping[str, Any] | str, None, None] | Mapping[str, Any] | None: - exec_params = AppExecutionParams.model_validate_json(payload) - - logger.info("workflow_based_app_execution_task run with params: %s", exec_params) - - runner = _AppRunner(db.engine, exec_params=exec_params) - return runner.run() - - -def _resume_app_execution(payload: dict[str, Any]) -> None: - workflow_run_id = payload["workflow_run_id"] - - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_factory) - - pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) - if pause_entity is None: - logger.warning("No pause entity found for workflow run %s", workflow_run_id) - return - - try: - resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) - except Exception: - logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id) - return - - generate_entity = resumption_context.get_generate_entity() - - graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) - - conversation = None - message = None - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = session.get(WorkflowRun, workflow_run_id) - if workflow_run is None: - logger.warning("Workflow run %s not found during resume", workflow_run_id) - return - - workflow = session.get(Workflow, workflow_run.workflow_id) - if workflow is None: - logger.warning("Workflow %s not found during resume", workflow_run.workflow_id) - return - - app_model = session.get(App, workflow_run.app_id) - if app_model is None: - logger.warning("App %s not found during resume", workflow_run.app_id) - return - - user = _resolve_user_for_run(session, workflow_run) - if user is None: - 
logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id) - return - - if isinstance(generate_entity, AdvancedChatAppGenerateEntity): - if generate_entity.conversation_id is None: - logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id) - return - - conversation = session.get(Conversation, generate_entity.conversation_id) - if conversation is None: - logger.warning( - "Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id - ) - return - - message = session.scalar( - select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc()) - ) - if message is None: - logger.warning("Message not found for workflow run %s", workflow_run_id) - return - - if not isinstance(generate_entity, (AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity)): - logger.error( - "Unsupported resumption entity for workflow run %s (found %s)", - workflow_run_id, - type(generate_entity), - ) - return - - workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity) - - pause_config = PauseStateLayerConfig( - session_factory=session_factory, - state_owner_user_id=workflow.created_by, - ) - - if isinstance(generate_entity, AdvancedChatAppGenerateEntity): - assert conversation is not None - assert message is not None - _resume_advanced_chat( - app_model=app_model, - workflow=workflow, - user=user, - conversation=conversation, - message=message, - generate_entity=generate_entity, - graph_runtime_state=graph_runtime_state, - session_factory=session_factory, - pause_state_config=pause_config, - workflow_run_id=workflow_run_id, - workflow_run=workflow_run, - ) - elif isinstance(generate_entity, WorkflowAppGenerateEntity): - _resume_workflow( - app_model=app_model, - workflow=workflow, - user=user, - generate_entity=generate_entity, - graph_runtime_state=graph_runtime_state, - session_factory=session_factory, - pause_state_config=pause_config, - workflow_run_id=workflow_run_id, - workflow_run=workflow_run, - workflow_run_repo=workflow_run_repo, - pause_entity=pause_entity, - ) - - -def _resume_advanced_chat( - *, - app_model: App, - workflow: Workflow, - user: Account | EndUser, - conversation: Conversation, - message: Message, - generate_entity: AdvancedChatAppGenerateEntity, - graph_runtime_state: GraphRuntimeState, - session_factory: sessionmaker, - pause_state_config: PauseStateLayerConfig, - workflow_run_id: str, - workflow_run: WorkflowRun, -) -> None: - try: - triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) - except ValueError: - triggered_from = WorkflowRunTriggeredFrom.APP_RUN - - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=app_model.id, - triggered_from=triggered_from, - ) - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=app_model.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - generator = AdvancedChatAppGenerator() - - try: - response = generator.resume( - app_model=app_model, - workflow=workflow, - user=user, - conversation=conversation, - message=message, - application_generate_entity=generate_entity, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - 
graph_runtime_state=graph_runtime_state, - pause_state_config=pause_state_config, - ) - except Exception: - logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id) - raise - - if generate_entity.stream: - assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT) - - -def _resume_workflow( - *, - app_model: App, - workflow: Workflow, - user: Account | EndUser, - generate_entity: WorkflowAppGenerateEntity, - graph_runtime_state: GraphRuntimeState, - session_factory: sessionmaker, - pause_state_config: PauseStateLayerConfig, - workflow_run_id: str, - workflow_run: WorkflowRun, - workflow_run_repo, - pause_entity, -) -> None: - try: - triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) - except ValueError: - triggered_from = WorkflowRunTriggeredFrom.APP_RUN - - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=app_model.id, - triggered_from=triggered_from, - ) - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=app_model.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - generator = WorkflowAppGenerator() - - try: - response = generator.resume( - app_model=app_model, - workflow=workflow, - user=user, - application_generate_entity=generate_entity, - graph_runtime_state=graph_runtime_state, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - pause_state_config=pause_state_config, - ) - except Exception: - logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id) - raise - - if generate_entity.stream: - assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW) - - workflow_run_repo.delete_workflow_pause(pause_entity) - - -@shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution") -def resume_app_execution(payload: dict[str, Any]) -> None: - _resume_app_execution(payload) diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index cc96542d4b..b51884148e 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -5,42 +5,32 @@ These tasks handle workflow execution for different subscription tiers with appropriate retry policies and error handling. 
""" -import logging from datetime import UTC, datetime from typing import Any from celery import shared_task from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session from configs import dify_config from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext -from core.app.layers.timeslice_layer import TimeSliceLayer +from core.app.entities.app_invoke_entities import InvokeFrom from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory -from core.repositories import DifyCoreRepositoryFactory -from core.workflow.runtime import GraphRuntimeState -from extensions.ext_database import db from models.account import Account -from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus +from models.enums import CreatorUserRole, WorkflowTriggerStatus from models.model import App, EndUser, Tenant from models.trigger import WorkflowTriggerLog -from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun -from repositories.factory import DifyAPIRepositoryFactory +from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import WorkflowNotFoundError from services.workflow.entities import ( TriggerData, - WorkflowResumeTaskData, WorkflowTaskData, ) from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity, AsyncWorkflowCFSPlanScheduler from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkflowSystemStrategy -logger = logging.getLogger(__name__) - @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE) def execute_workflow_professional(task_data_dict: dict[str, Any]): @@ -151,11 +141,6 @@ def _execute_workflow_common( if trigger_data.workflow_id: args["workflow_id"] = str(trigger_data.workflow_id) - pause_config = PauseStateLayerConfig( - session_factory=session_factory.get_session_maker(), - state_owner_user_id=workflow.created_by, - ) - # Execute the workflow with the trigger type generator.generate( app_model=app_model, @@ -171,7 +156,6 @@ def _execute_workflow_common( # TODO: Re-enable TimeSliceLayer after the HITL release. 
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id), ], - pause_state_config=pause_config, ) except Exception as e: @@ -189,153 +173,21 @@ def _execute_workflow_common( session.commit() -@shared_task(name="resume_workflow_execution") -def resume_workflow_execution(task_data_dict: dict[str, Any]) -> None: - """Resume a paused workflow run via Celery.""" - task_data = WorkflowResumeTaskData.model_validate(task_data_dict) - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory) - - pause_entity = workflow_run_repo.get_workflow_pause(task_data.workflow_run_id) - if pause_entity is None: - logger.warning("No pause state for workflow run %s", task_data.workflow_run_id) - return - workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(pause_entity.workflow_execution_id) - if workflow_run is None: - logger.warning("Workflow run not found for pause entity: pause_entity_id=%s", pause_entity.id) - return - - try: - resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) - except Exception as exc: - logger.exception("Failed to load resumption context for workflow run %s", task_data.workflow_run_id) - raise exc - - generate_entity = resumption_context.get_generate_entity() - if not isinstance(generate_entity, WorkflowAppGenerateEntity): - logger.error( - "Unsupported resumption entity for workflow run %s: %s", - task_data.workflow_run_id, - type(generate_entity), - ) - return - - graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) - - with session_factory() as session: - workflow = session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) - if workflow is None: - raise WorkflowNotFoundError( - "Workflow not found: workflow_run_id=%s, workflow_id=%s", workflow_run.id, workflow_run.workflow_id - ) - user = _get_user(session, workflow_run) - app_model = session.scalar(select(App).where(App.id == workflow_run.app_id)) - if app_model is None: - raise _AppNotFoundError( - "App not found: app_id=%s, workflow_run_id=%s", workflow_run.app_id, workflow_run.id - ) - - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=generate_entity.app_config.app_id, - triggered_from=WorkflowRunTriggeredFrom(workflow_run.triggered_from), - ) - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - pause_config = PauseStateLayerConfig( - session_factory=session_factory, - state_owner_user_id=workflow.created_by, - ) - - generator = WorkflowAppGenerator() - start_time = datetime.now(UTC) - graph_engine_layers = [] - trigger_log = _query_trigger_log_info(session_factory, task_data.workflow_run_id) - - if trigger_log: - cfs_plan_scheduler_entity = AsyncWorkflowCFSPlanEntity( - queue=AsyncWorkflowQueue(trigger_log.queue_name), - schedule_strategy=AsyncWorkflowSystemStrategy, - granularity=dify_config.ASYNC_WORKFLOW_SCHEDULER_GRANULARITY, - ) - cfs_plan_scheduler = AsyncWorkflowCFSPlanScheduler(plan=cfs_plan_scheduler_entity) - - graph_engine_layers.extend( - [ - TimeSliceLayer(cfs_plan_scheduler), - TriggerPostLayer(cfs_plan_scheduler_entity, start_time, 
trigger_log.id), - ] - ) - - workflow_run_repo.resume_workflow_pause(task_data.workflow_run_id, pause_entity) - - generator.resume( - app_model=app_model, - workflow=workflow, - user=user, - application_generate_entity=generate_entity, - graph_runtime_state=graph_runtime_state, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - graph_engine_layers=graph_engine_layers, - pause_state_config=pause_config, - ) - workflow_run_repo.delete_workflow_pause(pause_entity) - - -def _get_user(session: Session, workflow_run: WorkflowRun | WorkflowTriggerLog) -> Account | EndUser: +def _get_user(session: Session, trigger_log: WorkflowTriggerLog) -> Account | EndUser: """Compose user from trigger log""" - tenant = session.scalar(select(Tenant).where(Tenant.id == workflow_run.tenant_id)) + tenant = session.scalar(select(Tenant).where(Tenant.id == trigger_log.tenant_id)) if not tenant: - raise _TenantNotFoundError( - "Tenant not found for WorkflowRun: tenant_id=%s, workflow_run_id=%s", - workflow_run.tenant_id, - workflow_run.id, - ) + raise ValueError(f"Tenant not found: {trigger_log.tenant_id}") # Get user from trigger log - if workflow_run.created_by_role == CreatorUserRole.ACCOUNT: - user = session.scalar(select(Account).where(Account.id == workflow_run.created_by)) + if trigger_log.created_by_role == CreatorUserRole.ACCOUNT: + user = session.scalar(select(Account).where(Account.id == trigger_log.created_by)) if user: user.current_tenant = tenant else: # CreatorUserRole.END_USER - user = session.scalar(select(EndUser).where(EndUser.id == workflow_run.created_by)) + user = session.scalar(select(EndUser).where(EndUser.id == trigger_log.created_by)) if not user: - raise _UserNotFoundError( - "User not found: user_id=%s, created_by_role=%s, workflow_run_id=%s", - workflow_run.created_by, - workflow_run.created_by_role, - workflow_run.id, - ) + raise ValueError(f"User not found: {trigger_log.created_by} (role: {trigger_log.created_by_role})") return user - - -def _query_trigger_log_info(session_factory: sessionmaker[Session], workflow_run_id) -> WorkflowTriggerLog | None: - with session_factory() as session, session.begin(): - trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) - trigger_log = trigger_log_repo.get_by_workflow_run_id(workflow_run_id) - if not trigger_log: - logger.debug("Trigger log not found for workflow_run: workflow_run_id=%s", workflow_run_id) - return None - - return trigger_log - - -class _TenantNotFoundError(Exception): - pass - - -class _UserNotFoundError(Exception): - pass - - -class _AppNotFoundError(Exception): - pass diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py deleted file mode 100644 index 5413a33d6a..0000000000 --- a/api/tasks/human_input_timeout_tasks.py +++ /dev/null @@ -1,113 +0,0 @@ -import logging -from datetime import timedelta - -from celery import shared_task -from sqlalchemy import or_, select -from sqlalchemy.orm import sessionmaker - -from configs import dify_config -from core.repositories.human_input_repository import HumanInputFormSubmissionRepository -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from extensions.ext_database import db -from extensions.ext_storage import storage -from libs.datetime_utils import ensure_naive_utc, naive_utc_now -from models.human_input import HumanInputForm -from models.workflow import 
WorkflowPause, WorkflowRun -from services.human_input_service import HumanInputService - -logger = logging.getLogger(__name__) - - -def _is_global_timeout(form_model: HumanInputForm, global_timeout_seconds: int, *, now) -> bool: - if global_timeout_seconds <= 0: - return False - if form_model.workflow_run_id is None: - return False - created_at = ensure_naive_utc(form_model.created_at) - global_deadline = created_at + timedelta(seconds=global_timeout_seconds) - return global_deadline <= now - - -def _handle_global_timeout(*, form_id: str, workflow_run_id: str, node_id: str, session_factory: sessionmaker) -> None: - now = naive_utc_now() - with session_factory() as session, session.begin(): - workflow_run = session.get(WorkflowRun, workflow_run_id) - if workflow_run is not None: - workflow_run.status = WorkflowExecutionStatus.STOPPED - workflow_run.error = f"Human input global timeout at node {node_id}" - workflow_run.finished_at = now - session.add(workflow_run) - - pause_model = session.scalar(select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id)) - if pause_model is not None: - try: - storage.delete(pause_model.state_object_key) - except Exception: - logger.exception( - "Failed to delete pause state object for workflow_run_id=%s, pause_id=%s", - workflow_run_id, - pause_model.id, - ) - pause_model.resumed_at = now - session.add(pause_model) - - -@shared_task(name="human_input_form_timeout.check_and_resume", queue="schedule_executor") -def check_and_handle_human_input_timeouts(limit: int = 100) -> None: - """Scan for expired human input forms and resume or end workflows.""" - - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - form_repo = HumanInputFormSubmissionRepository(session_factory) - service = HumanInputService(session_factory, form_repository=form_repo) - now = naive_utc_now() - global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS - - with session_factory() as session: - global_deadline = now - timedelta(seconds=global_timeout_seconds) if global_timeout_seconds > 0 else None - timeout_filter = HumanInputForm.expiration_time <= now - if global_deadline is not None: - timeout_filter = or_(timeout_filter, HumanInputForm.created_at <= global_deadline) - stmt = ( - select(HumanInputForm) - .where( - HumanInputForm.status == HumanInputFormStatus.WAITING, - timeout_filter, - ) - .order_by(HumanInputForm.id.asc()) - .limit(limit) - ) - expired_forms = session.scalars(stmt).all() - - for form_model in expired_forms: - try: - if form_model.form_kind == HumanInputFormKind.DELIVERY_TEST: - form_repo.mark_timeout( - form_id=form_model.id, - timeout_status=HumanInputFormStatus.TIMEOUT, - reason="delivery_test_timeout", - ) - continue - - is_global = _is_global_timeout(form_model, global_timeout_seconds, now=now) - record = form_repo.mark_timeout( - form_id=form_model.id, - timeout_status=HumanInputFormStatus.EXPIRED if is_global else HumanInputFormStatus.TIMEOUT, - reason="global_timeout" if is_global else "node_timeout", - ) - assert record.workflow_run_id is not None, "workflow_run_id should not be None for non-test form" - if is_global: - _handle_global_timeout( - form_id=record.form_id, - workflow_run_id=record.workflow_run_id, - node_id=record.node_id, - session_factory=session_factory, - ) - else: - service.enqueue_resume(record.workflow_run_id) - except Exception: - logger.exception( - "Failed to handle timeout for form_id=%s workflow_run_id=%s", - form_model.id, - form_model.workflow_run_id, - ) diff --git 
a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py deleted file mode 100644 index d1cd0fbadc..0000000000 --- a/api/tasks/mail_human_input_delivery_task.py +++ /dev/null @@ -1,190 +0,0 @@ -import json -import logging -import time -from dataclasses import dataclass -from typing import Any - -import click -from celery import shared_task -from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker - -from configs import dify_config -from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod -from core.workflow.runtime import GraphRuntimeState, VariablePool -from extensions.ext_database import db -from extensions.ext_mail import mail -from models.human_input import ( - DeliveryMethodType, - HumanInputDelivery, - HumanInputForm, - HumanInputFormRecipient, - RecipientType, -) -from repositories.factory import DifyAPIRepositoryFactory -from services.feature_service import FeatureService - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class _EmailRecipient: - email: str - token: str - - -@dataclass(frozen=True) -class _EmailDeliveryJob: - form_id: str - subject: str - body: str - form_content: str - recipients: list[_EmailRecipient] - - -def _build_form_link(token: str) -> str: - base_url = dify_config.APP_WEB_URL - return f"{base_url.rstrip('/')}/form/{token}" - - -def _parse_recipient_payload(payload: str) -> tuple[str | None, RecipientType | None]: - try: - payload_dict: dict[str, Any] = json.loads(payload) - except Exception: - logger.exception("Failed to parse recipient payload") - return None, None - - return payload_dict.get("email"), payload_dict.get("TYPE") - - -def _load_email_jobs(session: Session, form: HumanInputForm) -> list[_EmailDeliveryJob]: - deliveries = session.scalars( - select(HumanInputDelivery).where( - HumanInputDelivery.form_id == form.id, - HumanInputDelivery.delivery_method_type == DeliveryMethodType.EMAIL, - ) - ).all() - jobs: list[_EmailDeliveryJob] = [] - for delivery in deliveries: - delivery_config = EmailDeliveryMethod.model_validate_json(delivery.channel_payload) - - recipients = session.scalars( - select(HumanInputFormRecipient).where(HumanInputFormRecipient.delivery_id == delivery.id) - ).all() - - recipient_entities: list[_EmailRecipient] = [] - for recipient in recipients: - email, recipient_type = _parse_recipient_payload(recipient.recipient_payload) - if recipient_type not in {RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL}: - continue - if not email: - continue - token = recipient.access_token - if not token: - continue - recipient_entities.append(_EmailRecipient(email=email, token=token)) - - if not recipient_entities: - continue - - jobs.append( - _EmailDeliveryJob( - form_id=form.id, - subject=delivery_config.config.subject, - body=delivery_config.config.body, - form_content=form.rendered_content, - recipients=recipient_entities, - ) - ) - return jobs - - -def _render_body( - body_template: str, - form_link: str, - *, - variable_pool: VariablePool | None, -) -> str: - body = EmailDeliveryConfig.render_body_template( - body=body_template, - url=form_link, - variable_pool=variable_pool, - ) - return body - - -def _load_variable_pool(workflow_run_id: str | None) -> VariablePool | None: - if not workflow_run_id: - return None - - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - workflow_run_repo = 
DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_factory) - pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) - if pause_entity is None: - logger.info("No pause state found for workflow run %s", workflow_run_id) - return None - - try: - resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) - except Exception: - logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id) - return None - - graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state) - return graph_runtime_state.variable_pool - - -def _open_session(session_factory: sessionmaker | Session | None): - if session_factory is None: - return Session(db.engine) - if isinstance(session_factory, Session): - return session_factory - return session_factory() - - -@shared_task(queue="mail") -def dispatch_human_input_email_task(form_id: str, node_title: str | None = None, session_factory=None): - if not mail.is_inited(): - return - - logger.info(click.style(f"Start human input email delivery for form {form_id}", fg="green")) - start_at = time.perf_counter() - - try: - with _open_session(session_factory) as session: - form = session.get(HumanInputForm, form_id) - if form is None: - logger.warning("Human input form not found, form_id=%s", form_id) - return - features = FeatureService.get_features(form.tenant_id) - if not features.human_input_email_delivery_enabled: - logger.info( - "Human input email delivery is not available for tenant=%s, form_id=%s", - form.tenant_id, - form_id, - ) - return - jobs = _load_email_jobs(session, form) - - variable_pool = _load_variable_pool(form.workflow_run_id) - - for job in jobs: - for recipient in job.recipients: - form_link = _build_form_link(recipient.token) - body = _render_body(job.body, form_link, variable_pool=variable_pool) - - mail.send( - to=recipient.email, - subject=job.subject, - html=body, - ) - - end_at = time.perf_counter() - logger.info( - click.style( - f"Human input email delivery succeeded for form {form_id}: latency: {end_at - start_at}", fg="green" - ) - ) - except Exception: - logger.exception("Send human input email failed, form_id=%s", form_id) diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 44adadeaa5..948cf8b3a0 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -1,4 +1,3 @@ -import logging import os import pathlib import random @@ -11,34 +10,26 @@ from flask.testing import FlaskClient from sqlalchemy.orm import Session from app_factory import create_app -from configs.app_config import DifyConfig from extensions.ext_database import db from models import Account, DifySetup, Tenant, TenantAccountJoin from services.account_service import AccountService, RegisterService -_DEFUALT_TEST_ENV = ".env" -_DEFAULT_VDB_TEST_ENV = "vdb.env" - -_logger = logging.getLogger(__name__) - # Loading the .env file if it exists def _load_env(): current_file_path = pathlib.Path(__file__).absolute() # Items later in the list have higher precedence. 
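# For example (hypothetical values): if .env sets VECTOR_STORE=weaviate
# and vdb.env sets VECTOR_STORE=milvus, the process ends up with
# VECTOR_STORE=milvus, because vdb.env is loaded later with override=True.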
- env_file_paths = [ - os.getenv("DIFY_TEST_ENV_FILE", str(current_file_path.parent / _DEFUALT_TEST_ENV)), - os.getenv("DIFY_VDB_TEST_ENV_FILE", str(current_file_path.parent / _DEFAULT_VDB_TEST_ENV)), - ] + files_to_load = [".env", "vdb.env"] - for env_path_str in env_file_paths: - if not pathlib.Path(env_path_str).exists(): - _logger.warning("specified configuration file %s not exist", env_path_str) + env_file_paths = [current_file_path.parent / i for i in files_to_load] + for path in env_file_paths: + if not path.exists(): + continue from dotenv import load_dotenv # Set `override=True` to ensure values from `vdb.env` take priority over values from `.env` - load_dotenv(str(env_path_str), override=True) + load_dotenv(str(path), override=True) _load_env() @@ -50,12 +41,6 @@ os.environ.setdefault("OPENDAL_SCHEME", "fs") _CACHED_APP = create_app() -@pytest.fixture(scope="session") -def dify_config() -> DifyConfig: - config = DifyConfig() # type: ignore - return config - - @pytest.fixture def flask_app() -> Flask: return _CACHED_APP diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py deleted file mode 100644 index e3f0d8a96e..0000000000 --- a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -Utilities and helpers for Redis broadcast channel integration tests. - -This module provides utility classes and functions for testing -Redis broadcast channel functionality. -""" - -from .test_data import ( - LARGE_MESSAGES, - SMALL_MESSAGES, - SPECIAL_MESSAGES, - BufferTestConfig, - ConcurrencyTestConfig, - ErrorTestConfig, -) -from .test_helpers import ( - ConcurrentPublisher, - SubscriptionMonitor, - assert_message_order, - measure_throughput, - wait_for_condition, -) - -__all__ = [ - "LARGE_MESSAGES", - "SMALL_MESSAGES", - "SPECIAL_MESSAGES", - "BufferTestConfig", - "ConcurrencyTestConfig", - "ConcurrentPublisher", - "ErrorTestConfig", - "SubscriptionMonitor", - "assert_message_order", - "measure_throughput", - "wait_for_condition", -] diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py deleted file mode 100644 index 2cccb08304..0000000000 --- a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -Test data and configuration classes for Redis broadcast channel integration tests. - -This module provides dataclasses and constants for test configurations, -message sets, and test scenarios. 
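Example (a minimal sketch; ``make_channel`` is a hypothetical factory for the channel under test)::

    import pytest

    @pytest.mark.parametrize("cfg", BUFFER_TEST_CONFIGS, ids=lambda c: c.description)
    def test_buffer_overflow(cfg):
        channel = make_channel(cfg.buffer_size, cfg.overflow_strategy)
        for msg in SMALL_MESSAGES[: cfg.message_count]:
            channel.publish(msg)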
-""" - -import dataclasses -from typing import Any - -from libs.broadcast_channel.channel import Overflow - - -@dataclasses.dataclass(frozen=True) -class BufferTestConfig: - """Configuration for buffer management tests.""" - - buffer_size: int - overflow_strategy: Overflow - message_count: int - expected_behavior: str - description: str - - -@dataclasses.dataclass(frozen=True) -class ConcurrencyTestConfig: - """Configuration for concurrency tests.""" - - publisher_count: int - subscriber_count: int - messages_per_publisher: int - test_duration: float - description: str - - -@dataclasses.dataclass(frozen=True) -class ErrorTestConfig: - """Configuration for error handling tests.""" - - error_type: str - test_input: Any - expected_exception: type[Exception] - description: str - - -# Test message sets for different scenarios -SMALL_MESSAGES = [ - b"msg_1", - b"msg_2", - b"msg_3", - b"msg_4", - b"msg_5", -] - -MEDIUM_MESSAGES = [ - b"medium_message_1_with_more_content", - b"medium_message_2_with_more_content", - b"medium_message_3_with_more_content", - b"medium_message_4_with_more_content", - b"medium_message_5_with_more_content", -] - -LARGE_MESSAGES = [ - b"large_message_" + b"x" * 1000, - b"large_message_" + b"y" * 1000, - b"large_message_" + b"z" * 1000, -] - -VERY_LARGE_MESSAGES = [ - b"very_large_message_" + b"x" * 10000, # ~10KB - b"very_large_message_" + b"y" * 50000, # ~50KB - b"very_large_message_" + b"z" * 100000, # ~100KB -] - -SPECIAL_MESSAGES = [ - b"", # Empty message - b"\x00\x01\x02", # Binary data with null bytes - "unicode_test_你好".encode(), # Unicode - b"special_chars_!@#$%^&*()_+-=[]{}|;':\",./<>?", # Special characters - b"newlines\n\r\t", # Control characters -] - -BINARY_MESSAGES = [ - bytes(range(256)), # All possible byte values - b"\xff\xfe\xfd\xfc\xfb\xfa\xf9\xf8", # High byte values - b"\x00\x01\x02\x03\x04\x05\x06\x07", # Low byte values -] - -# Buffer test configurations -BUFFER_TEST_CONFIGS = [ - BufferTestConfig( - buffer_size=3, - overflow_strategy=Overflow.DROP_OLDEST, - message_count=5, - expected_behavior="drop_oldest", - description="Drop oldest messages when buffer is full", - ), - BufferTestConfig( - buffer_size=3, - overflow_strategy=Overflow.DROP_NEWEST, - message_count=5, - expected_behavior="drop_newest", - description="Drop newest messages when buffer is full", - ), - BufferTestConfig( - buffer_size=3, - overflow_strategy=Overflow.BLOCK, - message_count=5, - expected_behavior="block", - description="Block when buffer is full", - ), -] - -# Concurrency test configurations -CONCURRENCY_TEST_CONFIGS = [ - ConcurrencyTestConfig( - publisher_count=1, - subscriber_count=1, - messages_per_publisher=10, - test_duration=5.0, - description="Single publisher, single subscriber", - ), - ConcurrencyTestConfig( - publisher_count=3, - subscriber_count=1, - messages_per_publisher=10, - test_duration=5.0, - description="Multiple publishers, single subscriber", - ), - ConcurrencyTestConfig( - publisher_count=1, - subscriber_count=3, - messages_per_publisher=10, - test_duration=5.0, - description="Single publisher, multiple subscribers", - ), - ConcurrencyTestConfig( - publisher_count=3, - subscriber_count=3, - messages_per_publisher=10, - test_duration=5.0, - description="Multiple publishers, multiple subscribers", - ), -] - -# Error test configurations -ERROR_TEST_CONFIGS = [ - ErrorTestConfig( - error_type="invalid_buffer_size", - test_input=0, - expected_exception=ValueError, - description="Zero buffer size should raise ValueError", - ), - ErrorTestConfig( - 
error_type="invalid_buffer_size", - test_input=-1, - expected_exception=ValueError, - description="Negative buffer size should raise ValueError", - ), - ErrorTestConfig( - error_type="invalid_buffer_size", - test_input=1.5, - expected_exception=TypeError, - description="Float buffer size should raise TypeError", - ), - ErrorTestConfig( - error_type="invalid_buffer_size", - test_input="invalid", - expected_exception=TypeError, - description="String buffer size should raise TypeError", - ), -] - -# Topic name test cases -TOPIC_NAME_TEST_CASES = [ - "simple_topic", - "topic_with_underscores", - "topic-with-dashes", - "topic.with.dots", - "topic_with_numbers_123", - "UPPERCASE_TOPIC", - "mixed_Case_Topic", - "topic_with_symbols_!@#$%", - "very_long_topic_name_" + "x" * 100, - "unicode_topic_你好", - "topic:with:colons", - "topic/with/slashes", - "topic\\with\\backslashes", -] - -# Performance test configurations -PERFORMANCE_TEST_CONFIGS = [ - { - "name": "small_messages_high_frequency", - "message_size": 50, - "message_count": 1000, - "description": "Many small messages", - }, - { - "name": "medium_messages_medium_frequency", - "message_size": 500, - "message_count": 100, - "description": "Medium messages", - }, - { - "name": "large_messages_low_frequency", - "message_size": 5000, - "message_count": 10, - "description": "Large messages", - }, -] - -# Stress test configurations -STRESS_TEST_CONFIGS = [ - { - "name": "high_frequency_publishing", - "publisher_count": 5, - "messages_per_publisher": 100, - "subscriber_count": 3, - "description": "High frequency publishing with multiple publishers", - }, - { - "name": "many_subscribers", - "publisher_count": 1, - "messages_per_publisher": 50, - "subscriber_count": 10, - "description": "Many subscribers to single publisher", - }, - { - "name": "mixed_load", - "publisher_count": 3, - "messages_per_publisher": 100, - "subscriber_count": 5, - "description": "Mixed load with multiple publishers and subscribers", - }, -] - -# Edge case test data -EDGE_CASE_MESSAGES = [ - b"", # Empty message - b"\x00", # Single null byte - b"\xff", # Single max byte value - b"a", # Single ASCII character - "ä".encode(), # Single unicode character (2 bytes) - "𐍈".encode(), # Unicode character outside BMP (4 bytes) - b"\x00" * 1000, # 1000 null bytes - b"\xff" * 1000, # 1000 max byte values -] - -# Message validation test data -MESSAGE_VALIDATION_TEST_CASES = [ - { - "name": "valid_bytes", - "input": b"valid_message", - "should_pass": True, - "description": "Valid bytes message", - }, - { - "name": "empty_bytes", - "input": b"", - "should_pass": True, - "description": "Empty bytes message", - }, - { - "name": "binary_data", - "input": bytes(range(256)), - "should_pass": True, - "description": "Binary data with all byte values", - }, - { - "name": "large_message", - "input": b"x" * 1000000, # 1MB - "should_pass": True, - "description": "Large message (1MB)", - }, -] - -# Redis connection test scenarios -REDIS_CONNECTION_TEST_SCENARIOS = [ - { - "name": "normal_connection", - "should_fail": False, - "description": "Normal Redis connection", - }, - { - "name": "connection_timeout", - "should_fail": True, - "description": "Connection timeout scenario", - }, - { - "name": "connection_refused", - "should_fail": True, - "description": "Connection refused scenario", - }, -] - -# Test constants -DEFAULT_TIMEOUT = 10.0 -SHORT_TIMEOUT = 2.0 -LONG_TIMEOUT = 30.0 - -# Message size limits for testing -MAX_SMALL_MESSAGE_SIZE = 100 -MAX_MEDIUM_MESSAGE_SIZE = 1000 -MAX_LARGE_MESSAGE_SIZE = 
10000 - -# Thread counts for concurrency testing -MIN_THREAD_COUNT = 1 -MAX_THREAD_COUNT = 10 -DEFAULT_THREAD_COUNT = 3 - -# Buffer sizes for testing -MIN_BUFFER_SIZE = 1 -MAX_BUFFER_SIZE = 1000 -DEFAULT_BUFFER_SIZE = 10 diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py deleted file mode 100644 index 65f3007b01..0000000000 --- a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py +++ /dev/null @@ -1,396 +0,0 @@ -""" -Test helper utilities for Redis broadcast channel integration tests. - -This module provides utility classes and functions for testing concurrent -operations, monitoring subscriptions, and measuring performance. -""" - -import logging -import threading -import time -from collections.abc import Callable -from typing import Any - -_logger = logging.getLogger(__name__) - - -class ConcurrentPublisher: - """ - Utility class for publishing messages concurrently from multiple threads. - - This class manages multiple publisher threads that can publish messages - to the same or different topics concurrently, useful for stress testing - and concurrency validation. - """ - - def __init__(self, producer, message_count: int = 10, delay: float = 0.0): - """ - Initialize the concurrent publisher. - - Args: - producer: The producer instance to publish with - message_count: Number of messages to publish per thread - delay: Delay between messages in seconds - """ - self.producer = producer - self.message_count = message_count - self.delay = delay - self.threads: list[threading.Thread] = [] - self.published_messages: list[list[bytes]] = [] - self._lock = threading.Lock() - self._started = False - - def start_publishers(self, thread_count: int = 3) -> None: - """ - Start multiple publisher threads. - - Args: - thread_count: Number of publisher threads to start - """ - if self._started: - raise RuntimeError("Publishers already started") - - self._started = True - - def _publisher(thread_id: int) -> None: - messages: list[bytes] = [] - for i in range(self.message_count): - message = f"thread_{thread_id}_msg_{i}".encode() - try: - self.producer.publish(message) - messages.append(message) - if self.delay > 0: - time.sleep(self.delay) - except Exception: - _logger.exception("Publisher %s failed", thread_id) - - with self._lock: - self.published_messages.append(messages) - - for thread_id in range(thread_count): - thread = threading.Thread( - target=_publisher, - args=(thread_id,), - name=f"publisher-{thread_id}", - daemon=True, - ) - thread.start() - self.threads.append(thread) - - def wait_for_completion(self, timeout: float = 30.0) -> bool: - """ - Wait for all publisher threads to complete. - - Args: - timeout: Maximum time to wait in seconds - - Returns: - bool: True if all threads completed successfully - """ - for thread in self.threads: - thread.join(timeout) - if thread.is_alive(): - return False - return True - - def get_all_messages(self) -> list[bytes]: - """ - Get all messages published by all threads. - - Returns: - list[bytes]: Flattened list of all published messages - """ - with self._lock: - all_messages = [] - for thread_messages in self.published_messages: - all_messages.extend(thread_messages) - return all_messages - - def get_thread_messages(self, thread_id: int) -> list[bytes]: - """ - Get messages published by a specific thread.
- - Args: - thread_id: ID of the thread - - Returns: - list[bytes]: Messages published by the specified thread - """ - with self._lock: - if 0 <= thread_id < len(self.published_messages): - return self.published_messages[thread_id].copy() - return [] - - -class SubscriptionMonitor: - """ - Utility class for monitoring subscription activity in tests. - - This class monitors a subscription and tracks message reception, - errors, and completion status for testing purposes. - """ - - def __init__(self, subscription, timeout: float = 10.0): - """ - Initialize the subscription monitor. - - Args: - subscription: The subscription to monitor - timeout: Default timeout for operations - """ - self.subscription = subscription - self.timeout = timeout - self.messages: list[bytes] = [] - self.errors: list[Exception] = [] - self.completed = False - self._lock = threading.Lock() - self._condition = threading.Condition(self._lock) - self._monitor_thread: threading.Thread | None = None - self._start_time: float | None = None - - def start_monitoring(self) -> None: - """Start monitoring the subscription in a separate thread.""" - if self._monitor_thread is not None: - raise RuntimeError("Monitoring already started") - - self._start_time = time.time() - - def _monitor(): - try: - for message in self.subscription: - with self._lock: - self.messages.append(message) - self._condition.notify_all() - except Exception as e: - with self._lock: - self.errors.append(e) - self._condition.notify_all() - finally: - with self._lock: - self.completed = True - self._condition.notify_all() - - self._monitor_thread = threading.Thread( - target=_monitor, - name="subscription-monitor", - daemon=True, - ) - self._monitor_thread.start() - - def wait_for_messages(self, count: int, timeout: float | None = None) -> bool: - """ - Wait for a specific number of messages. - - Args: - count: Number of messages to wait for - timeout: Timeout in seconds (uses default if None) - - Returns: - bool: True if expected messages were received - """ - if timeout is None: - timeout = self.timeout - - deadline = time.time() + timeout - - with self._condition: - while len(self.messages) < count and not self.completed: - remaining = deadline - time.time() - if remaining <= 0: - return False - self._condition.wait(remaining) - - return len(self.messages) >= count - - def wait_for_completion(self, timeout: float | None = None) -> bool: - """ - Wait for monitoring to complete. - - Args: - timeout: Timeout in seconds (uses default if None) - - Returns: - bool: True if monitoring completed successfully - """ - if timeout is None: - timeout = self.timeout - - deadline = time.time() + timeout - - with self._condition: - while not self.completed: - remaining = deadline - time.time() - if remaining <= 0: - return False - self._condition.wait(remaining) - - return True - - def get_messages(self) -> list[bytes]: - """ - Get all received messages. - - Returns: - list[bytes]: Copy of received messages - """ - with self._lock: - return self.messages.copy() - - def get_error_count(self) -> int: - """ - Get the number of errors encountered. - - Returns: - int: Number of errors - """ - with self._lock: - return len(self.errors) - - def get_elapsed_time(self) -> float: - """ - Get the elapsed monitoring time. 
- - Returns: - float: Elapsed time in seconds - """ - if self._start_time is None: - return 0.0 - return time.time() - self._start_time - - def stop(self) -> None: - """Stop monitoring and close the subscription.""" - if self._monitor_thread is not None: - self.subscription.close() - self._monitor_thread.join(timeout=1.0) - - -def assert_message_order(received: list[bytes], expected: list[bytes]) -> bool: - """ - Assert that messages were received in the expected order. - - Args: - received: List of received messages - expected: List of expected messages in order - - Returns: - bool: True if order matches expected - """ - if len(received) != len(expected): - return False - - for i, (recv_msg, exp_msg) in enumerate(zip(received, expected)): - if recv_msg != exp_msg: - _logger.error("Message order mismatch at index %s: expected %s, got %s", i, exp_msg, recv_msg) - return False - - return True - - -def measure_throughput( - operation: Callable[[], Any], - duration: float = 1.0, -) -> tuple[float, int]: - """ - Measure the throughput of an operation over a specified duration. - - Args: - operation: The operation to measure - duration: Duration to run the operation in seconds - - Returns: - tuple[float, int]: (operations per second, total operations) - """ - start_time = time.time() - end_time = start_time + duration - count = 0 - - while time.time() < end_time: - try: - operation() - count += 1 - except Exception: - _logger.exception("Operation failed") - break - - elapsed = time.time() - start_time - ops_per_sec = count / elapsed if elapsed > 0 else 0.0 - - return ops_per_sec, count - - -def wait_for_condition( - condition: Callable[[], bool], - timeout: float = 10.0, - interval: float = 0.1, -) -> bool: - """ - Wait for a condition to become true. - - Args: - condition: Function that returns True when condition is met - timeout: Maximum time to wait in seconds - interval: Check interval in seconds - - Returns: - bool: True if condition was met within timeout - """ - deadline = time.time() + timeout - - while time.time() < deadline: - if condition(): - return True - time.sleep(interval) - - return False - - -def create_stress_test_messages( - count: int, - size: int = 100, -) -> list[bytes]: - """ - Create messages for stress testing. - - Args: - count: Number of messages to create - size: Size of each message in bytes - - Returns: - list[bytes]: List of test messages - """ - messages = [] - for i in range(count): - message = f"stress_test_msg_{i:06d}_".ljust(size, "x").encode() - messages.append(message) - return messages - - -def validate_message_integrity( - original_messages: list[bytes], - received_messages: list[bytes], -) -> dict[str, Any]: - """ - Validate the integrity of received messages. 
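Example (illustrative values)::

    result = validate_message_integrity([b"a", b"b"], [b"b"])
    assert result["missing_count"] == 1   # b"a" was sent but never received
    assert result["extra_count"] == 0
    assert result["integrity_ok"] is False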
- - Args: - original_messages: Messages that were sent - received_messages: Messages that were received - - Returns: - dict[str, Any]: Validation results - """ - original_set = set(original_messages) - received_set = set(received_messages) - - missing_messages = original_set - received_set - extra_messages = received_set - original_set - - return { - "total_sent": len(original_messages), - "total_received": len(received_messages), - "missing_count": len(missing_messages), - "extra_count": len(extra_messages), - "missing_messages": list(missing_messages), - "extra_messages": list(extra_messages), - "integrity_ok": len(missing_messages) == 0 and len(extra_messages) == 0, - } diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py deleted file mode 100644 index 7fad603a6d..0000000000 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ /dev/null @@ -1,166 +0,0 @@ -"""TestContainers integration tests for ChatConversationApi status_count behavior.""" - -import json -import uuid - -from flask.testing import FlaskClient -from sqlalchemy.orm import Session - -from configs import dify_config -from constants import HEADER_NAME_CSRF_TOKEN -from core.workflow.enums import WorkflowExecutionStatus -from libs.datetime_utils import naive_utc_now -from libs.token import _real_cookie_name, generate_csrf_token -from models import Account, DifySetup, Tenant, TenantAccountJoin -from models.account import AccountStatus, TenantAccountRole -from models.enums import CreatorUserRole -from models.model import App, AppMode, Conversation, Message -from models.workflow import WorkflowRun -from services.account_service import AccountService - - -def _create_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: - account = Account( - email=f"test-{uuid.uuid4()}@example.com", - name="Test User", - interface_language="en-US", - status=AccountStatus.ACTIVE, - ) - account.initialized_at = naive_utc_now() - db_session.add(account) - db_session.commit() - - tenant = Tenant(name="Test Tenant", status="normal") - db_session.add(tenant) - db_session.commit() - - join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=TenantAccountRole.OWNER, - current=True, - ) - db_session.add(join) - db_session.commit() - - account.set_tenant_id(tenant.id) - account.timezone = "UTC" - db_session.commit() - - dify_setup = DifySetup(version=dify_config.project.version) - db_session.add(dify_setup) - db_session.commit() - - return account, tenant - - -def _create_app(db_session: Session, tenant_id: str, account_id: str) -> App: - app = App( - tenant_id=tenant_id, - name="Test Chat App", - mode=AppMode.CHAT, - enable_site=True, - enable_api=True, - created_by=account_id, - ) - db_session.add(app) - db_session.commit() - return app - - -def _create_conversation(db_session: Session, app_id: str, account_id: str) -> Conversation: - conversation = Conversation( - app_id=app_id, - name="Test Conversation", - inputs={}, - status="normal", - mode=AppMode.CHAT, - from_source=CreatorUserRole.ACCOUNT, - from_account_id=account_id, - ) - db_session.add(conversation) - db_session.commit() - return conversation - - -def _create_workflow_run(db_session: Session, app_id: str, tenant_id: str, account_id: str) -> WorkflowRun: - workflow_run = WorkflowRun( - tenant_id=tenant_id, - 
app_id=app_id, - workflow_id=str(uuid.uuid4()), - type="chat", - triggered_from="app-run", - version="1.0.0", - graph=json.dumps({"nodes": [], "edges": []}), - inputs=json.dumps({"query": "test"}), - status=WorkflowExecutionStatus.PAUSED, - outputs=json.dumps({}), - elapsed_time=0.0, - total_tokens=0, - total_steps=0, - created_by_role=CreatorUserRole.ACCOUNT, - created_by=account_id, - created_at=naive_utc_now(), - ) - db_session.add(workflow_run) - db_session.commit() - return workflow_run - - -def _create_message( - db_session: Session, app_id: str, conversation_id: str, workflow_run_id: str, account_id: str -) -> Message: - message = Message( - app_id=app_id, - conversation_id=conversation_id, - query="Hello", - message={"type": "text", "content": "Hello"}, - answer="Hi there", - message_tokens=1, - answer_tokens=1, - message_unit_price=0.001, - answer_unit_price=0.001, - message_price_unit=0.001, - answer_price_unit=0.001, - currency="USD", - status="normal", - from_source=CreatorUserRole.ACCOUNT, - from_account_id=account_id, - workflow_run_id=workflow_run_id, - inputs={"query": "Hello"}, - ) - db_session.add(message) - db_session.commit() - return message - - -def test_chat_conversation_status_count_includes_paused( - db_session_with_containers: Session, - test_client_with_containers: FlaskClient, -): - account, tenant = _create_account_and_tenant(db_session_with_containers) - app = _create_app(db_session_with_containers, tenant.id, account.id) - conversation = _create_conversation(db_session_with_containers, app.id, account.id) - conversation_id = conversation.id - workflow_run = _create_workflow_run(db_session_with_containers, app.id, tenant.id, account.id) - _create_message(db_session_with_containers, app.id, conversation.id, workflow_run.id, account.id) - - access_token = AccountService.get_account_jwt_token(account) - csrf_token = generate_csrf_token(account.id) - cookie_name = _real_cookie_name("csrf_token") - - test_client_with_containers.set_cookie(cookie_name, csrf_token, domain="localhost") - response = test_client_with_containers.get( - f"/console/api/apps/{app.id}/chat-conversations", - headers={ - "Authorization": f"Bearer {access_token}", - HEADER_NAME_CSRF_TOKEN: csrf_token, - }, - ) - - assert response.status_code == 200 - payload = response.get_json() - assert payload is not None - assert payload["total"] == 1 - assert payload["data"][0]["id"] == conversation_id - assert payload["data"][0]["status_count"]["paused"] == 1 diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py deleted file mode 100644 index 079e4934bb..0000000000 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ /dev/null @@ -1,240 +0,0 @@ -"""TestContainers integration tests for HumanInputFormRepositoryImpl.""" - -from __future__ import annotations - -from uuid import uuid4 - -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session - -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.nodes.human_input.entities import ( - DeliveryChannelConfig, - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - FormDefinition, - HumanInputNodeData, - MemberRecipient, - UserAction, - WebAppDeliveryMethod, -) -from core.workflow.repositories.human_input_form_repository import 
FormCreateParams -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.human_input import ( - EmailExternalRecipientPayload, - EmailMemberRecipientPayload, - HumanInputForm, - HumanInputFormRecipient, - RecipientType, -) - - -def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]: - tenant = Tenant(name="Test Tenant", status="normal") - session.add(tenant) - session.flush() - - members: list[Account] = [] - for index, email in enumerate(member_emails): - account = Account( - email=email, - name=f"Member {index}", - interface_language="en-US", - status="active", - ) - session.add(account) - session.flush() - - tenant_join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=TenantAccountRole.NORMAL, - current=True, - ) - session.add(tenant_join) - members.append(account) - - session.commit() - return tenant, members - - -def _build_form_params(delivery_methods: list[DeliveryChannelConfig]) -> FormCreateParams: - form_config = HumanInputNodeData( - title="Human Approval", - delivery_methods=delivery_methods, - form_content="
Approve?
", - user_actions=[UserAction(id="approve", title="Approve")], - ) - return FormCreateParams( - app_id=str(uuid4()), - workflow_execution_id=str(uuid4()), - node_id="human-input-node", - form_config=form_config, - rendered_content="
Approve?
", - delivery_methods=delivery_methods, - display_in_ui=False, - resolved_default_values={}, - ) - - -def _build_email_delivery( - whole_workspace: bool, recipients: list[MemberRecipient | ExternalRecipient] -) -> EmailDeliveryMethod: - return EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients), - subject="Approval Needed", - body="Please review", - ) - ) - - -class TestHumanInputFormRepositoryImplWithContainers: - def test_create_form_with_whole_workspace_recipients(self, db_session_with_containers: Session) -> None: - engine = db_session_with_containers.get_bind() - assert isinstance(engine, Engine) - tenant, members = _create_tenant_with_members( - db_session_with_containers, - member_emails=["member1@example.com", "member2@example.com"], - ) - - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) - params = _build_form_params( - delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], - ) - - form_entity = repository.create_form(params) - - with Session(engine) as verification_session: - recipients = verification_session.scalars( - select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id) - ).all() - - assert len(recipients) == len(members) - member_payloads = [ - EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload) - for recipient in recipients - if recipient.recipient_type == RecipientType.EMAIL_MEMBER - ] - member_emails = {payload.email for payload in member_payloads} - assert member_emails == {member.email for member in members} - - def test_create_form_with_specific_members_and_external(self, db_session_with_containers: Session) -> None: - engine = db_session_with_containers.get_bind() - assert isinstance(engine, Engine) - tenant, members = _create_tenant_with_members( - db_session_with_containers, - member_emails=["primary@example.com", "secondary@example.com"], - ) - - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) - params = _build_form_params( - delivery_methods=[ - _build_email_delivery( - whole_workspace=False, - recipients=[ - MemberRecipient(user_id=members[0].id), - ExternalRecipient(email="external@example.com"), - ], - ) - ], - ) - - form_entity = repository.create_form(params) - - with Session(engine) as verification_session: - recipients = verification_session.scalars( - select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id) - ).all() - - member_recipient_payloads = [ - EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload) - for recipient in recipients - if recipient.recipient_type == RecipientType.EMAIL_MEMBER - ] - assert len(member_recipient_payloads) == 1 - assert member_recipient_payloads[0].user_id == members[0].id - - external_payloads = [ - EmailExternalRecipientPayload.model_validate_json(recipient.recipient_payload) - for recipient in recipients - if recipient.recipient_type == RecipientType.EMAIL_EXTERNAL - ] - assert len(external_payloads) == 1 - assert external_payloads[0].email == "external@example.com" - - def test_create_form_persists_default_values(self, db_session_with_containers: Session) -> None: - engine = db_session_with_containers.get_bind() - assert isinstance(engine, Engine) - tenant, _ = _create_tenant_with_members( - db_session_with_containers, - member_emails=["prefill@example.com"], - ) - - repository = HumanInputFormRepositoryImpl(session_factory=engine, 
tenant_id=tenant.id) - resolved_values = {"greeting": "Hello!"} - params = FormCreateParams( - app_id=str(uuid4()), - workflow_execution_id=str(uuid4()), - node_id="human-input-node", - form_config=HumanInputNodeData( - title="Human Approval", - form_content="
Approve?
", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ), - rendered_content="
Approve?
", - delivery_methods=[], - display_in_ui=False, - resolved_default_values=resolved_values, - ) - - form_entity = repository.create_form(params) - - with Session(engine) as verification_session: - form_model = verification_session.scalars( - select(HumanInputForm).where(HumanInputForm.id == form_entity.id) - ).first() - - assert form_model is not None - definition = FormDefinition.model_validate_json(form_model.form_definition) - assert definition.default_values == resolved_values - - def test_create_form_persists_display_in_ui(self, db_session_with_containers: Session) -> None: - engine = db_session_with_containers.get_bind() - assert isinstance(engine, Engine) - tenant, _ = _create_tenant_with_members( - db_session_with_containers, - member_emails=["ui@example.com"], - ) - - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) - params = FormCreateParams( - app_id=str(uuid4()), - workflow_execution_id=str(uuid4()), - node_id="human-input-node", - form_config=HumanInputNodeData( - title="Human Approval", - form_content="
Approve?
", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - delivery_methods=[WebAppDeliveryMethod()], - ), - rendered_content="
Approve?
", - delivery_methods=[WebAppDeliveryMethod()], - display_in_ui=True, - resolved_default_values={}, - ) - - form_entity = repository.create_form(params) - - with Session(engine) as verification_session: - form_model = verification_session.scalars( - select(HumanInputForm).where(HumanInputForm.id == form_entity.id) - ).first() - - assert form_model is not None - definition = FormDefinition.model_validate_json(form_model.form_definition) - assert definition.display_in_ui is True diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py deleted file mode 100644 index 06d55177eb..0000000000 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ /dev/null @@ -1,336 +0,0 @@ -import time -import uuid -from datetime import timedelta -from unittest.mock import MagicMock - -import pytest -from sqlalchemy import delete, select -from sqlalchemy.orm import Session - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowType -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from models import Account -from models.account import Tenant, TenantAccountJoin, TenantAccountRole -from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom -from models.model import App, AppMode, IconType -from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun - - -def _mock_form_repository_without_submission() -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = False - repo.create_form.return_value = form_entity - repo.get_form.return_value = None - return repo - - -def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - 
form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = True - form_entity.selected_action_id = action_id - form_entity.submitted_data = {} - form_entity.status = HumanInputFormStatus.WAITING - form_entity.expiration_time = naive_utc_now() + timedelta(hours=1) - repo.get_form.return_value = form_entity - return repo - - -def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - workflow_execution_id=workflow_execution_id, - app_id=app_id, - workflow_id=workflow_id, - user_id=user_id, - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph( - runtime_state: GraphRuntimeState, - tenant_id: str, - app_id: str, - workflow_id: str, - user_id: str, - form_repository: HumanInputFormRepository, -) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = GraphInitParams( - tenant_id=tenant_id, - app_id=app_id, - workflow_id=workflow_id, - graph_config=graph_config, - user_id=user_id, - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_data = StartNodeData(title="start", variables=[]) - start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="human", - form_content="Awaiting human input", - inputs=[], - user_actions=[ - UserAction(id="continue", title="Continue"), - ], - ) - human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - form_repository=form_repository, - ) - - end_data = EndNodeData( - title="end", - outputs=[], - desc=None, - ) - end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_node) - .add_node(end_node, from_node_id="human", source_handle="continue") - .build() - ) - - -def _build_generate_entity( - tenant_id: str, - app_id: str, - workflow_id: str, - workflow_execution_id: str, - user_id: str, -) -> WorkflowAppGenerateEntity: - app_config = WorkflowUIBasedAppConfig( - tenant_id=tenant_id, - app_id=app_id, - app_mode=AppMode.WORKFLOW, - workflow_id=workflow_id, - ) - return WorkflowAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - inputs={}, - files=[], - user_id=user_id, - stream=False, - invoke_from=InvokeFrom.DEBUGGER, - workflow_execution_id=workflow_execution_id, - ) - - -class TestHumanInputResumeNodeExecutionIntegration: - @pytest.fixture(autouse=True) - def setup_test_data(self, db_session_with_containers: Session): - tenant = Tenant( - name="Test Tenant", - status="normal", - ) - db_session_with_containers.add(tenant) - db_session_with_containers.commit() - - account = Account( - email="test@example.com", - name="Test User", - interface_language="en-US", - status="active", - ) - db_session_with_containers.add(account) - db_session_with_containers.commit() - - tenant_join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - 
role=TenantAccountRole.OWNER, - current=True, - ) - db_session_with_containers.add(tenant_join) - db_session_with_containers.commit() - - account.current_tenant = tenant - - app = App( - tenant_id=tenant.id, - name="Test App", - description="", - mode=AppMode.WORKFLOW.value, - icon_type=IconType.EMOJI.value, - icon="rocket", - icon_background="#4ECDC4", - enable_site=False, - enable_api=False, - api_rpm=0, - api_rph=0, - is_demo=False, - is_public=False, - is_universal=False, - max_active_requests=None, - created_by=account.id, - updated_by=account.id, - ) - db_session_with_containers.add(app) - db_session_with_containers.commit() - - workflow = Workflow( - tenant_id=tenant.id, - app_id=app.id, - type="workflow", - version="draft", - graph='{"nodes": [], "edges": []}', - features='{"file_upload": {"enabled": false}}', - created_by=account.id, - created_at=naive_utc_now(), - ) - db_session_with_containers.add(workflow) - db_session_with_containers.commit() - - self.session = db_session_with_containers - self.tenant = tenant - self.account = account - self.app = app - self.workflow = workflow - - yield - - self.session.execute(delete(WorkflowNodeExecutionModel)) - self.session.execute(delete(WorkflowRun)) - self.session.execute(delete(Workflow).where(Workflow.id == self.workflow.id)) - self.session.execute(delete(App).where(App.id == self.app.id)) - self.session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == self.tenant.id)) - self.session.execute(delete(Account).where(Account.id == self.account.id)) - self.session.execute(delete(Tenant).where(Tenant.id == self.tenant.id)) - self.session.commit() - - def _build_persistence_layer(self, execution_id: str) -> WorkflowPersistenceLayer: - generate_entity = _build_generate_entity( - tenant_id=self.tenant.id, - app_id=self.app.id, - workflow_id=self.workflow.id, - workflow_execution_id=execution_id, - user_id=self.account.id, - ) - execution_repo = SQLAlchemyWorkflowExecutionRepository( - session_factory=self.session.get_bind(), - user=self.account, - app_id=self.app.id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - ) - node_execution_repo = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=self.session.get_bind(), - user=self.account, - app_id=self.app.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - return WorkflowPersistenceLayer( - application_generate_entity=generate_entity, - workflow_info=PersistenceWorkflowInfo( - workflow_id=self.workflow.id, - workflow_type=WorkflowType.WORKFLOW, - version=self.workflow.version, - graph_data=self.workflow.graph_dict, - ), - workflow_execution_repository=execution_repo, - workflow_node_execution_repository=node_execution_repo, - ) - - def _run_graph(self, graph: Graph, runtime_state: GraphRuntimeState, execution_id: str) -> None: - engine = GraphEngine( - workflow_id=self.workflow.id, - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - ) - engine.layer(self._build_persistence_layer(execution_id)) - for _ in engine.run(): - continue - - def test_resume_human_input_does_not_create_duplicate_node_execution(self): - execution_id = str(uuid.uuid4()) - runtime_state = _build_runtime_state( - workflow_execution_id=execution_id, - app_id=self.app.id, - workflow_id=self.workflow.id, - user_id=self.account.id, - ) - pause_repo = _mock_form_repository_without_submission() - paused_graph = _build_graph( - runtime_state, - self.tenant.id, - self.app.id, - self.workflow.id, - self.account.id, - pause_repo, - ) - 
self._run_graph(paused_graph, runtime_state, execution_id) - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resume_repo = _mock_form_repository_with_submission(action_id="continue") - resumed_graph = _build_graph( - resumed_state, - self.tenant.id, - self.app.id, - self.workflow.id, - self.account.id, - resume_repo, - ) - self._run_graph(resumed_graph, resumed_state, execution_id) - - stmt = select(WorkflowNodeExecutionModel).where( - WorkflowNodeExecutionModel.workflow_run_id == execution_id, - WorkflowNodeExecutionModel.node_id == "human", - ) - records = self.session.execute(stmt).scalars().all() - assert len(records) == 1 - assert records[0].status != "paused" - assert records[0].triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN - assert records[0].created_by_role == CreatorUserRole.ACCOUNT diff --git a/api/tests/test_containers_integration_tests/helpers/__init__.py b/api/tests/test_containers_integration_tests/helpers/__init__.py deleted file mode 100644 index 40d03889a9..0000000000 --- a/api/tests/test_containers_integration_tests/helpers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Helper utilities for integration tests.""" diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py deleted file mode 100644 index 19d7772c39..0000000000 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ /dev/null @@ -1,154 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from datetime import datetime, timedelta -from decimal import Decimal -from uuid import uuid4 - -from core.workflow.nodes.human_input.entities import FormDefinition, UserAction -from models.account import Account, Tenant, TenantAccountJoin -from models.execution_extra_content import HumanInputContent -from models.human_input import HumanInputForm, HumanInputFormStatus -from models.model import App, Conversation, Message - - -@dataclass -class HumanInputMessageFixture: - app: App - account: Account - conversation: Conversation - message: Message - form: HumanInputForm - action_id: str - action_text: str - node_title: str - - -def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: - tenant = Tenant(name=f"Tenant {uuid4()}") - db_session.add(tenant) - db_session.flush() - - account = Account( - name=f"Account {uuid4()}", - email=f"human_input_{uuid4()}@example.com", - password="hashed-password", - password_salt="salt", - interface_language="en-US", - timezone="UTC", - ) - db_session.add(account) - db_session.flush() - - tenant_join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role="owner", - current=True, - ) - db_session.add(tenant_join) - db_session.flush() - - app = App( - tenant_id=tenant.id, - name=f"App {uuid4()}", - description="", - mode="chat", - icon_type="emoji", - icon="🤖", - icon_background="#FFFFFF", - enable_site=False, - enable_api=True, - api_rpm=100, - api_rph=100, - is_demo=False, - is_public=False, - is_universal=False, - created_by=account.id, - updated_by=account.id, - ) - db_session.add(app) - db_session.flush() - - conversation = Conversation( - app_id=app.id, - mode="chat", - name="Test Conversation", - summary="", - introduction="", - system_instruction="", - status="normal", - invoke_from="console", - from_source="console", - from_account_id=account.id, - from_end_user_id=None, - ) - conversation.inputs = {} - 
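The deleted resume test above reduces to one pattern: serialize the paused run with runtime_state.dumps(), rebuild it via GraphRuntimeState.from_snapshot(...), and re-run under the same workflow_execution_id, so the "human" node's execution record is updated in place rather than duplicated. A dependency-free sketch of that snapshot round-trip; ToyRuntimeState is a hypothetical stand-in, not Dify's actual class:

import json


class ToyRuntimeState:
    """Illustrative stand-in for a serializable engine runtime state."""

    def __init__(self, execution_id: str, variables: dict) -> None:
        self.execution_id = execution_id
        self.variables = variables

    def dumps(self) -> str:
        # Serialize everything needed to resume the run later.
        return json.dumps({"execution_id": self.execution_id, "variables": self.variables})

    @classmethod
    def from_snapshot(cls, snapshot: str) -> "ToyRuntimeState":
        # Rebuild equivalent state from the serialized snapshot.
        data = json.loads(snapshot)
        return cls(data["execution_id"], data["variables"])


paused = ToyRuntimeState("run-1", {"human": "waiting"})
resumed = ToyRuntimeState.from_snapshot(paused.dumps())
# Resuming keeps the same execution id, so a persistence layer can update the
# existing node-execution row instead of inserting a duplicate.
assert resumed.execution_id == paused.execution_id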
db_session.add(conversation) - db_session.flush() - - workflow_run_id = str(uuid4()) - message = Message( - app_id=app.id, - conversation_id=conversation.id, - inputs={}, - query="Human input query", - message={"messages": []}, - answer="Human input answer", - message_tokens=50, - message_unit_price=Decimal("0.001"), - answer_tokens=80, - answer_unit_price=Decimal("0.001"), - provider_response_latency=0.5, - currency="USD", - from_source="console", - from_account_id=account.id, - workflow_run_id=workflow_run_id, - ) - db_session.add(message) - db_session.flush() - - action_id = "approve" - action_text = "Approve request" - node_title = "Approval" - form_definition = FormDefinition( - form_content="content", - inputs=[], - user_actions=[UserAction(id=action_id, title=action_text)], - rendered_content="Rendered block", - expiration_time=datetime.utcnow() + timedelta(days=1), - node_title=node_title, - display_in_ui=True, - ) - form = HumanInputForm( - tenant_id=tenant.id, - app_id=app.id, - workflow_run_id=workflow_run_id, - node_id="node-id", - form_definition=form_definition.model_dump_json(), - rendered_content="Rendered block", - status=HumanInputFormStatus.SUBMITTED, - expiration_time=datetime.utcnow() + timedelta(days=1), - selected_action_id=action_id, - ) - db_session.add(form) - db_session.flush() - - content = HumanInputContent( - workflow_run_id=workflow_run_id, - message_id=message.id, - form_id=form.id, - ) - db_session.add(content) - db_session.commit() - - return HumanInputMessageFixture( - app=app, - account=account, - conversation=conversation, - message=message, - form=form, - action_id=action_id, - action_text=action_text, - node_title=node_title, - ) diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py index 43915a204d..d612e70910 100644 --- a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py @@ -16,7 +16,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest import redis -from redis.cluster import RedisCluster from testcontainers.redis import RedisContainer from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic @@ -333,95 +332,3 @@ class TestShardedRedisBroadcastChannelIntegration: # Verify subscriptions are cleaned up topic_subscribers_after = self._get_sharded_numsub(redis_client, topic_name) assert topic_subscribers_after == 0 - - -class TestShardedRedisBroadcastChannelClusterIntegration: - """Integration tests for sharded pub/sub with RedisCluster client.""" - - @pytest.fixture(scope="class") - def redis_cluster_container(self) -> Iterator[RedisContainer]: - """Create a Redis 7 container with cluster mode enabled.""" - command = ( - "redis-server --port 6379 " - "--cluster-enabled yes " - "--cluster-config-file nodes.conf " - "--cluster-node-timeout 5000 " - "--appendonly no " - "--protected-mode no" - ) - with RedisContainer(image="redis:7-alpine").with_command(command) as container: - yield container - - @classmethod - def _get_test_topic_name(cls) -> str: - return f"test_sharded_cluster_topic_{uuid.uuid4()}" - - @staticmethod - def _ensure_single_node_cluster(host: str, port: int) -> None: - client = redis.Redis(host=host, port=port, decode_responses=False) - client.config_set("cluster-announce-ip", host) - 
client.config_set("cluster-announce-port", port) - slots = client.execute_command("CLUSTER", "SLOTS") - if not slots: - client.execute_command("CLUSTER", "ADDSLOTSRANGE", 0, 16383) - - deadline = time.time() + 5.0 - while time.time() < deadline: - info = client.execute_command("CLUSTER", "INFO") - info_text = info.decode("utf-8") if isinstance(info, (bytes, bytearray)) else str(info) - if "cluster_state:ok" in info_text: - return - time.sleep(0.05) - raise RuntimeError("Redis cluster did not become ready in time") - - @pytest.fixture(scope="class") - def redis_cluster_client(self, redis_cluster_container: RedisContainer) -> RedisCluster: - host = redis_cluster_container.get_container_host_ip() - port = int(redis_cluster_container.get_exposed_port(6379)) - self._ensure_single_node_cluster(host, port) - return RedisCluster(host=host, port=port, decode_responses=False) - - @pytest.fixture - def broadcast_channel(self, redis_cluster_client: RedisCluster) -> BroadcastChannel: - return ShardedRedisBroadcastChannel(redis_cluster_client) - - def test_cluster_sharded_pubsub_delivers_message(self, broadcast_channel: BroadcastChannel): - """Ensure sharded subscription receives messages when using RedisCluster client.""" - topic_name = self._get_test_topic_name() - message = b"cluster sharded message" - - topic = broadcast_channel.topic(topic_name) - producer = topic.as_producer() - subscription = topic.subscribe() - ready_event = threading.Event() - - def consumer_thread() -> list[bytes]: - received = [] - try: - _ = subscription.receive(0.01) - except SubscriptionClosedError: - return received - ready_event.set() - deadline = time.time() + 5.0 - while time.time() < deadline: - msg = subscription.receive(timeout=0.1) - if msg is None: - continue - received.append(msg) - break - subscription.close() - return received - - def producer_thread(): - if not ready_event.wait(timeout=2.0): - pytest.fail("subscriber did not become ready before publish") - producer.publish(message) - - with ThreadPoolExecutor(max_workers=2) as executor: - consumer_future = executor.submit(consumer_thread) - producer_future = executor.submit(producer_thread) - - producer_future.result(timeout=5.0) - received_messages = consumer_future.result(timeout=5.0) - - assert received_messages == [message] diff --git a/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py b/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py deleted file mode 100644 index 178fc2e4fb..0000000000 --- a/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -Integration tests for RateLimiter using testcontainers Redis. 
-""" - -import uuid - -import pytest - -from extensions.ext_redis import redis_client -from libs import helper as helper_module - - -@pytest.mark.usefixtures("flask_app_with_containers") -def test_rate_limiter_counts_multiple_attempts_in_same_second(monkeypatch): - prefix = f"test_rate_limit:{uuid.uuid4().hex}" - limiter = helper_module.RateLimiter(prefix=prefix, max_attempts=2, time_window=60) - key = limiter._get_key("203.0.113.10") - - redis_client.delete(key) - monkeypatch.setattr(helper_module.time, "time", lambda: 1_700_000_000) - - limiter.increment_rate_limit("203.0.113.10") - limiter.increment_rate_limit("203.0.113.10") - - assert limiter.is_rate_limited("203.0.113.10") is True diff --git a/api/tests/test_containers_integration_tests/models/test_account.py b/api/tests/test_containers_integration_tests/models/test_account.py deleted file mode 100644 index 078dc0e8de..0000000000 --- a/api/tests/test_containers_integration_tests/models/test_account.py +++ /dev/null @@ -1,79 +0,0 @@ -# import secrets - -# import pytest -# from sqlalchemy import select -# from sqlalchemy.orm import Session -# from sqlalchemy.orm.exc import DetachedInstanceError - -# from libs.datetime_utils import naive_utc_now -# from models.account import Account, Tenant, TenantAccountJoin - - -# @pytest.fixture -# def session(db_session_with_containers): -# with Session(db_session_with_containers.get_bind()) as session: -# yield session - - -# @pytest.fixture -# def account(session): -# account = Account( -# name="test account", -# email=f"test_{secrets.token_hex(8)}@example.com", -# ) -# session.add(account) -# session.commit() -# return account - - -# @pytest.fixture -# def tenant(session): -# tenant = Tenant(name="test tenant") -# session.add(tenant) -# session.commit() -# return tenant - - -# @pytest.fixture -# def tenant_account_join(session, account, tenant): -# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id) -# session.add(tenant_join) -# session.commit() -# yield tenant_join -# session.delete(tenant_join) -# session.commit() - - -# class TestAccountTenant: -# def test_set_current_tenant_should_reload_tenant( -# self, -# db_session_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session: -# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one() -# account.current_tenant = scoped_tenant -# scoped_tenant.created_at = naive_utc_now() -# # session.commit() - -# # Ensure the tenant used in assignment is detached. 
-# with pytest.raises(DetachedInstanceError): -# _ = scoped_tenant.name - -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id - -# def test_set_tenant_id_should_load_tenant_as_not_expire( -# self, -# flask_app_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with flask_app_with_containers.test_request_context(): -# account.set_tenant_id(tenant.id) - -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id diff --git a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py deleted file mode 100644 index c9058626d1..0000000000 --- a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from sqlalchemy.orm import sessionmaker - -from extensions.ext_database import db -from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository -from tests.test_containers_integration_tests.helpers.execution_extra_content import ( - create_human_input_message_fixture, -) - - -def test_get_by_message_ids_returns_human_input_content(db_session_with_containers): - fixture = create_human_input_message_fixture(db_session_with_containers) - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=sessionmaker(bind=db.engine, expire_on_commit=False) - ) - - results = repository.get_by_message_ids([fixture.message.id]) - - assert len(results) == 1 - assert len(results[0]) == 1 - content = results[0][0] - assert content.submitted is True - assert content.form_submission_data is not None - assert content.form_submission_data.action_id == fixture.action_id - assert content.form_submission_data.action_text == fixture.action_text - assert content.form_submission_data.rendered_content == fixture.form.rendered_content diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 4b6b5048a1..4d4e77a802 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -2293,12 +2293,6 @@ class TestRegisterService: mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False - from extensions.ext_database import db - from models.model import DifySetup - - db.session.query(DifySetup).delete() - db.session.commit() - # Execute setup RegisterService.setup( email=admin_email, @@ -2309,7 +2303,9 @@ class TestRegisterService: ) # Verify account was created + from extensions.ext_database import db from models import Account + from models.model import DifySetup account = db.session.query(Account).filter_by(email=admin_email).first() assert account is not None diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 81bfa0ea20..476f58585d 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ 
b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -1,5 +1,5 @@ import uuid -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest from faker import Faker @@ -26,7 +26,6 @@ class TestAppGenerateService: patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator, patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator, patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator, - patch("services.app_generate_service.MessageBasedAppGenerator") as mock_message_based_generator, patch("services.account_service.FeatureService") as mock_account_feature_service, patch("services.app_generate_service.dify_config") as mock_dify_config, patch("configs.dify_config") as mock_global_dify_config, @@ -39,13 +38,9 @@ class TestAppGenerateService: # Setup default mock returns for workflow service mock_workflow_service_instance = mock_workflow_service.return_value - mock_published_workflow = MagicMock(spec=Workflow) - mock_published_workflow.id = str(uuid.uuid4()) - mock_workflow_service_instance.get_published_workflow.return_value = mock_published_workflow - mock_draft_workflow = MagicMock(spec=Workflow) - mock_draft_workflow.id = str(uuid.uuid4()) - mock_workflow_service_instance.get_draft_workflow.return_value = mock_draft_workflow - mock_workflow_service_instance.get_published_workflow_by_id.return_value = mock_published_workflow + mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow) + mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow) + mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow) # Setup default mock returns for rate limiting mock_rate_limit_instance = mock_rate_limit.return_value @@ -71,8 +66,6 @@ class TestAppGenerateService: mock_advanced_chat_generator_instance.generate.return_value = ["advanced_chat_response"] mock_advanced_chat_generator_instance.single_iteration_generate.return_value = ["single_iteration_response"] mock_advanced_chat_generator_instance.single_loop_generate.return_value = ["single_loop_response"] - mock_advanced_chat_generator_instance.retrieve_events.return_value = ["advanced_chat_events"] - mock_advanced_chat_generator_instance.convert_to_event_stream.return_value = ["advanced_chat_stream"] mock_advanced_chat_generator.convert_to_event_stream.return_value = ["advanced_chat_stream"] mock_workflow_generator_instance = mock_workflow_generator.return_value @@ -83,8 +76,6 @@ class TestAppGenerateService: mock_workflow_generator_instance.single_loop_generate.return_value = ["workflow_single_loop_response"] mock_workflow_generator.convert_to_event_stream.return_value = ["workflow_stream"] - mock_message_based_generator.retrieve_events.return_value = ["workflow_events"] - # Setup default mock returns for account service mock_account_feature_service.get_system_features.return_value.is_allow_register = True @@ -97,7 +88,6 @@ class TestAppGenerateService: mock_global_dify_config.BILLING_ENABLED = False mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100 mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000 - mock_global_dify_config.HOSTED_POOL_CREDITS = 1000 yield { "billing_service": mock_billing_service, @@ -108,7 +98,6 @@ class TestAppGenerateService: "agent_chat_generator": mock_agent_chat_generator, "advanced_chat_generator": mock_advanced_chat_generator, 
"workflow_generator": mock_workflow_generator, - "message_based_generator": mock_message_based_generator, "account_feature_service": mock_account_feature_service, "dify_config": mock_dify_config, "global_dify_config": mock_global_dify_config, @@ -291,10 +280,8 @@ class TestAppGenerateService: assert result == ["test_response"] # Verify advanced chat generator was called - mock_external_service_dependencies["advanced_chat_generator"].return_value.retrieve_events.assert_called_once() - mock_external_service_dependencies[ - "advanced_chat_generator" - ].return_value.convert_to_event_stream.assert_called_once() + mock_external_service_dependencies["advanced_chat_generator"].return_value.generate.assert_called_once() + mock_external_service_dependencies["advanced_chat_generator"].convert_to_event_stream.assert_called_once() def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -317,7 +304,7 @@ class TestAppGenerateService: assert result == ["test_response"] # Verify workflow generator was called - mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once() + mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once() def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies): @@ -983,27 +970,14 @@ class TestAppGenerateService: } # Execute the method under test - with patch("services.app_generate_service.AppExecutionParams") as mock_exec_params: - mock_payload = MagicMock() - mock_payload.workflow_run_id = fake.uuid4() - mock_payload.model_dump_json.return_value = "{}" - mock_exec_params.new.return_value = mock_payload - - result = AppGenerateService.generate( - app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True - ) + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) # Verify the result assert result == ["test_response"] - # Verify payload was built with complex args - mock_exec_params.new.assert_called_once() - call_kwargs = mock_exec_params.new.call_args.kwargs - assert call_kwargs["args"] == args - - # Verify workflow streaming event retrieval was used - mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once_with( - ANY, - mock_payload.workflow_run_id, - on_subscribe=ANY, - ) + # Verify workflow generator was called with complex args + mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once() + call_args = mock_external_service_dependencies["workflow_generator"].return_value.generate.call_args + assert call_args[1]["args"] == args diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py deleted file mode 100644 index 9c978f830f..0000000000 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ /dev/null @@ -1,112 +0,0 @@ -import json -import uuid -from unittest.mock import MagicMock - -import pytest - -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - HumanInputNodeData, -) -from 
models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.model import App, AppMode -from models.workflow import Workflow, WorkflowType -from services.workflow_service import WorkflowService - - -def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) -> tuple[App, Account]: - tenant = Tenant(name="Test Tenant") - account = Account(name="Tester", email="tester@example.com") - session.add_all([tenant, account]) - session.flush() - - session.add( - TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - current=True, - role=TenantAccountRole.OWNER.value, - ) - ) - - app = App( - tenant_id=tenant.id, - name="Test App", - description="", - mode=AppMode.WORKFLOW.value, - icon_type="emoji", - icon="app", - icon_background="#ffffff", - enable_site=True, - enable_api=True, - created_by=account.id, - updated_by=account.id, - ) - session.add(app) - session.flush() - - email_method = EmailDeliveryMethod( - id=delivery_method_id, - enabled=True, - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ExternalRecipient(email="recipient@example.com")], - ), - subject="Test {{recipient_email}}", - body="Body {{#url#}} {{form_content}}", - ), - ) - node_data = HumanInputNodeData( - title="Human Input", - delivery_methods=[email_method], - form_content="Hello Human Input", - inputs=[], - user_actions=[], - ).model_dump(mode="json") - node_data["type"] = NodeType.HUMAN_INPUT.value - graph = json.dumps({"nodes": [{"id": "human-node", "data": node_data}], "edges": []}) - - workflow = Workflow.new( - tenant_id=tenant.id, - app_id=app.id, - type=WorkflowType.WORKFLOW.value, - version=Workflow.VERSION_DRAFT, - graph=graph, - features=json.dumps({}), - created_by=account.id, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], - ) - session.add(workflow) - session.commit() - - return app, account - - -def test_human_input_delivery_test_sends_email( - db_session_with_containers, - monkeypatch: pytest.MonkeyPatch, -) -> None: - delivery_method_id = uuid.uuid4() - app, account = _create_app_with_draft_workflow(db_session_with_containers, delivery_method_id=delivery_method_id) - - send_mock = MagicMock() - monkeypatch.setattr("services.human_input_delivery_test_service.mail.is_inited", lambda: True) - monkeypatch.setattr("services.human_input_delivery_test_service.mail.send", send_mock) - - service = WorkflowService() - service.test_human_input_delivery( - app_model=app, - account=account, - node_id="human-node", - delivery_method_id=str(delivery_method_id), - ) - - assert send_mock.call_count == 1 - assert send_mock.call_args.kwargs["to"] == "recipient@example.com" diff --git a/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py b/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py deleted file mode 100644 index 44e5a82868..0000000000 --- a/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations - -import pytest - -from services.message_service import MessageService -from tests.test_containers_integration_tests.helpers.execution_extra_content import ( - create_human_input_message_fixture, -) - - -@pytest.mark.usefixtures("flask_req_ctx_with_containers") -def test_pagination_returns_extra_contents(db_session_with_containers): - fixture = 
create_human_input_message_fixture(db_session_with_containers) - - pagination = MessageService.pagination_by_first_id( - app_model=fixture.app, - user=fixture.account, - conversation_id=fixture.conversation.id, - first_id=None, - limit=10, - ) - - assert pagination.data - message = pagination.data[0] - assert message.extra_contents == [ - { - "type": "human_input", - "workflow_run_id": fixture.message.workflow_run_id, - "submitted": True, - "form_submission_data": { - "node_id": fixture.form.node_id, - "node_title": fixture.node_title, - "rendered_content": fixture.form.rendered_content, - "action_id": fixture.action_id, - "action_text": fixture.action_text, - }, - } - ] diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 3a88081db3..23c4eeb82f 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -465,27 +465,6 @@ class TestWorkflowRunService: db.session.add(node_execution) node_executions.append(node_execution) - paused_node_execution = WorkflowNodeExecutionModel( - tenant_id=app.tenant_id, - app_id=app.id, - workflow_id=workflow_run.workflow_id, - triggered_from="workflow-run", - workflow_run_id=workflow_run.id, - index=99, - node_id="node_paused", - node_type="human_input", - title="Paused Node", - inputs=json.dumps({"input": "paused"}), - process_data=json.dumps({"process": "paused"}), - status="paused", - elapsed_time=0.5, - execution_metadata=json.dumps({"tokens": 0}), - created_by_role=CreatorUserRole.ACCOUNT, - created_by=account.id, - created_at=datetime.now(UTC), - ) - db.session.add(paused_node_execution) - db.session.commit() # Act: Execute the method under test @@ -494,19 +473,16 @@ class TestWorkflowRunService: # Assert: Verify the expected outcomes assert result is not None - assert len(result) == 4 + assert len(result) == 3 # Verify node execution properties - statuses = [node_execution.status for node_execution in result] - assert "paused" in statuses - assert statuses.count("succeeded") == 3 - assert statuses.count("paused") == 1 - for node_execution in result: assert node_execution.tenant_id == app.tenant_id assert node_execution.app_id == app.id assert node_execution.workflow_run_id == workflow_run.id - assert node_execution.node_id.startswith("node_") + assert node_execution.index in [0, 1, 2] # Check that index is one of the expected values + assert node_execution.node_id.startswith("node_") # Check that node_id starts with "node_" + assert node_execution.status == "succeeded" def test_get_workflow_run_node_executions_empty( self, db_session_with_containers, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index acd9d78c91..3d46735a1a 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -4,7 +4,6 @@ from unittest.mock import patch import pytest from faker import Faker -from core.tools.errors import WorkflowToolHumanInputNotSupportedError from models.tools import WorkflowToolProvider from models.workflow import Workflow as WorkflowModel from 
services.account_service import AccountService, TenantService @@ -508,62 +507,6 @@ class TestWorkflowToolManageService: assert tool_count == 0 - def test_create_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test workflow tool creation fails when workflow contains human input nodes. - - This test verifies: - - Human input nodes prevent workflow tool publishing - - Correct error message - - No database changes when workflow is invalid - """ - fake = Faker() - - # Create test data - app, account, workflow = self._create_test_app_and_account( - db_session_with_containers, mock_external_service_dependencies - ) - - workflow.graph = json.dumps( - { - "nodes": [ - { - "id": "human_input_node", - "data": {"type": "human-input"}, - } - ] - } - ) - - tool_parameters = self._create_test_workflow_tool_parameters() - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - WorkflowToolManageService.create_workflow_tool( - user_id=account.id, - tenant_id=account.current_tenant.id, - workflow_app_id=app.id, - name=fake.word(), - label=fake.word(), - icon={"type": "emoji", "emoji": "🔧"}, - description=fake.text(max_nb_chars=200), - parameters=tool_parameters, - ) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - - from extensions.ext_database import db - - tool_count = ( - db.session.query(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == account.current_tenant.id, - ) - .count() - ) - - assert tool_count == 0 - def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): """ Test successful workflow tool update with valid parameters. @@ -650,80 +593,6 @@ class TestWorkflowToolManageService: mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called() mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called() - def test_update_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test workflow tool update fails when workflow contains human input nodes. 
- - This test verifies: - - Human input nodes prevent workflow tool updates - - Correct error message - - Existing tool data remains unchanged - """ - fake = Faker() - - # Create test data - app, account, workflow = self._create_test_app_and_account( - db_session_with_containers, mock_external_service_dependencies - ) - - # Create initial workflow tool - initial_tool_name = fake.word() - initial_tool_parameters = self._create_test_workflow_tool_parameters() - WorkflowToolManageService.create_workflow_tool( - user_id=account.id, - tenant_id=account.current_tenant.id, - workflow_app_id=app.id, - name=initial_tool_name, - label=fake.word(), - icon={"type": "emoji", "emoji": "🔧"}, - description=fake.text(max_nb_chars=200), - parameters=initial_tool_parameters, - ) - - from extensions.ext_database import db - - created_tool = ( - db.session.query(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == account.current_tenant.id, - WorkflowToolProvider.app_id == app.id, - ) - .first() - ) - - original_name = created_tool.name - - workflow.graph = json.dumps( - { - "nodes": [ - { - "id": "human_input_node", - "data": {"type": "human-input"}, - } - ] - } - ) - db.session.commit() - - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - WorkflowToolManageService.update_workflow_tool( - user_id=account.id, - tenant_id=account.current_tenant.id, - workflow_tool_id=created_tool.id, - name=fake.word(), - label=fake.word(), - icon={"type": "emoji", "emoji": "⚙️"}, - description=fake.text(max_nb_chars=200), - parameters=initial_tool_parameters, - ) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - - db.session.refresh(created_tool) - assert created_tool.name == original_name - def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): """ Test workflow tool update fails when tool does not exist. 
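Both workflow-tool tests deleted above build a draft graph containing a node whose data.type is "human-input" and expect WorkflowToolHumanInputNotSupportedError (error code workflow_tool_human_input_not_supported) from the create/update calls. The guard they exercised amounts to scanning the serialized graph for that node type; a minimal sketch under that assumption, where graph_contains_node_type is an illustrative helper, not the service's actual implementation:

import json


def graph_contains_node_type(graph_json: str, node_type: str) -> bool:
    """Return True if any node in a serialized workflow graph has the given type."""
    graph = json.loads(graph_json or "{}")
    return any(node.get("data", {}).get("type") == node_type for node in graph.get("nodes", []))


# Shaped like the graph the deleted tests constructed before expecting the error.
graph = json.dumps({"nodes": [{"id": "human_input_node", "data": {"type": "human-input"}}]})
assert graph_contains_node_type(graph, "human-input")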
diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py deleted file mode 100644 index 5fd6c56f7a..0000000000 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ /dev/null @@ -1,214 +0,0 @@ -import uuid -from datetime import UTC, datetime -from unittest.mock import patch - -import pytest - -from configs import dify_config -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - HumanInputNodeData, - MemberRecipient, -) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from extensions.ext_storage import storage -from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole -from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom -from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient -from models.model import AppMode -from models.workflow import WorkflowPause, WorkflowRun, WorkflowType -from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task - - -@pytest.fixture(autouse=True) -def cleanup_database(db_session_with_containers): - db_session_with_containers.query(HumanInputFormRecipient).delete() - db_session_with_containers.query(HumanInputDelivery).delete() - db_session_with_containers.query(HumanInputForm).delete() - db_session_with_containers.query(WorkflowPause).delete() - db_session_with_containers.query(WorkflowRun).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() - db_session_with_containers.commit() - - -def _create_workspace_member(db_session_with_containers): - account = Account( - email="owner@example.com", - name="Owner", - password="password", - interface_language="en-US", - status=AccountStatus.ACTIVE, - ) - account.created_at = datetime.now(UTC) - account.updated_at = datetime.now(UTC) - db_session_with_containers.add(account) - db_session_with_containers.commit() - db_session_with_containers.refresh(account) - - tenant = Tenant(name="Test Tenant") - tenant.created_at = datetime.now(UTC) - tenant.updated_at = datetime.now(UTC) - db_session_with_containers.add(tenant) - db_session_with_containers.commit() - db_session_with_containers.refresh(tenant) - - tenant_join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=TenantAccountRole.OWNER, - ) - tenant_join.created_at = datetime.now(UTC) - tenant_join.updated_at = datetime.now(UTC) - db_session_with_containers.add(tenant_join) - db_session_with_containers.commit() - - return tenant, account - - -def _build_form(db_session_with_containers, tenant, account, *, app_id: str, workflow_execution_id: str): - delivery_method = EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ - MemberRecipient(user_id=account.id), - 
ExternalRecipient(email="external@example.com"), - ], - ), - subject="Action needed {{ node_title }} {{#node1.value#}}", - body="Token {{ form_token }} link {{#url#}} content {{#node1.value#}}", - ) - ) - - node_data = HumanInputNodeData( - title="Review", - form_content="Form content", - delivery_methods=[delivery_method], - ) - - engine = db_session_with_containers.get_bind() - repo = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) - params = FormCreateParams( - app_id=app_id, - workflow_execution_id=workflow_execution_id, - node_id="node-1", - form_config=node_data, - rendered_content="Rendered", - delivery_methods=node_data.delivery_methods, - display_in_ui=False, - resolved_default_values={}, - ) - return repo.create_form(params) - - -def _create_workflow_pause_state( - db_session_with_containers, - *, - workflow_run_id: str, - workflow_id: str, - tenant_id: str, - app_id: str, - account_id: str, - variable_pool: VariablePool, -): - workflow_run = WorkflowRun( - id=workflow_run_id, - tenant_id=tenant_id, - app_id=app_id, - workflow_id=workflow_id, - type=WorkflowType.WORKFLOW, - triggered_from=WorkflowRunTriggeredFrom.APP_RUN, - version="1", - graph="{}", - inputs="{}", - status=WorkflowExecutionStatus.PAUSED, - created_by_role=CreatorUserRole.ACCOUNT, - created_by=account_id, - created_at=datetime.now(UTC), - ) - db_session_with_containers.add(workflow_run) - - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) - resumption_context = WorkflowResumptionContext( - generate_entity={ - "type": AppMode.WORKFLOW, - "entity": WorkflowAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=WorkflowUIBasedAppConfig( - tenant_id=tenant_id, - app_id=app_id, - app_mode=AppMode.WORKFLOW, - workflow_id=workflow_id, - ), - inputs={}, - files=[], - user_id=account_id, - stream=False, - invoke_from=InvokeFrom.WEB_APP, - workflow_execution_id=workflow_run_id, - ), - }, - serialized_graph_runtime_state=runtime_state.dumps(), - ) - - state_object_key = f"workflow_pause_states/{workflow_run_id}.json" - storage.save(state_object_key, resumption_context.dumps().encode()) - - pause_state = WorkflowPause( - workflow_id=workflow_id, - workflow_run_id=workflow_run_id, - state_object_key=state_object_key, - ) - db_session_with_containers.add(pause_state) - db_session_with_containers.commit() - - -def test_dispatch_human_input_email_task_integration(monkeypatch: pytest.MonkeyPatch, db_session_with_containers): - tenant, account = _create_workspace_member(db_session_with_containers) - workflow_run_id = str(uuid.uuid4()) - workflow_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - variable_pool = VariablePool() - variable_pool.add(["node1", "value"], "OK") - _create_workflow_pause_state( - db_session_with_containers, - workflow_run_id=workflow_run_id, - workflow_id=workflow_id, - tenant_id=tenant.id, - app_id=app_id, - account_id=account.id, - variable_pool=variable_pool, - ) - form_entity = _build_form( - db_session_with_containers, - tenant, - account, - app_id=app_id, - workflow_execution_id=workflow_run_id, - ) - - monkeypatch.setattr(dify_config, "APP_WEB_URL", "https://app.example.com") - - with patch("tasks.mail_human_input_delivery_task.mail") as mock_mail: - mock_mail.is_inited.return_value = True - - dispatch_human_input_email_task(form_id=form_entity.id, node_title="Approval") - - assert mock_mail.send.call_count == 2 - send_args = [call.kwargs for call in mock_mail.send.call_args_list] - recipients = {kwargs["to"] for kwargs in send_args} - 
assert recipients == {"owner@example.com", "external@example.com"} - assert all(kwargs["subject"] == "Action needed {{ node_title }} {{#node1.value#}}" for kwargs in send_args) - assert all("app.example.com/form/" in kwargs["html"] for kwargs in send_args) - assert all("content OK" in kwargs["html"] for kwargs in send_args) - assert all("{{ form_token }}" in kwargs["html"] for kwargs in send_args) diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 5f4f28cf4f..889e3d1d83 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -94,6 +94,11 @@ class PrunePausesTestCase: def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]: """Create test cases for pause workflow failure scenarios.""" return [ + PauseWorkflowFailureCase( + name="pause_already_paused_workflow", + initial_status=WorkflowExecutionStatus.PAUSED, + description="Should fail to pause an already paused workflow", + ), PauseWorkflowFailureCase( name="pause_completed_workflow", initial_status=WorkflowExecutionStatus.SUCCEEDED, diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index cf52980e57..6fce7849f9 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -164,62 +164,6 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): assert "timezone=UTC" in options -def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch): - os.environ.clear() - - monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") - monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") - monkeypatch.setenv("DB_USERNAME", "postgres") - monkeypatch.setenv("DB_PASSWORD", "postgres") - monkeypatch.setenv("DB_HOST", "localhost") - monkeypatch.setenv("DB_PORT", "5432") - monkeypatch.setenv("DB_DATABASE", "dify") - monkeypatch.setenv("REDIS_HOST", "redis.example.com") - monkeypatch.setenv("REDIS_PORT", "6380") - monkeypatch.setenv("REDIS_USERNAME", "user") - monkeypatch.setenv("REDIS_PASSWORD", "pass@word") - monkeypatch.setenv("REDIS_DB", "2") - monkeypatch.setenv("REDIS_USE_SSL", "true") - - config = DifyConfig() - - assert config.normalized_pubsub_redis_url == "rediss://user:pass%40word@redis.example.com:6380/2" - assert config.PUBSUB_REDIS_CHANNEL_TYPE == "pubsub" - - -def test_pubsub_redis_url_override(monkeypatch: pytest.MonkeyPatch): - os.environ.clear() - - monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") - monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") - monkeypatch.setenv("DB_USERNAME", "postgres") - monkeypatch.setenv("DB_PASSWORD", "postgres") - monkeypatch.setenv("DB_HOST", "localhost") - monkeypatch.setenv("DB_PORT", "5432") - monkeypatch.setenv("DB_DATABASE", "dify") - monkeypatch.setenv("PUBSUB_REDIS_URL", "redis://pubsub-host:6381/5") - - config = DifyConfig() - - assert config.normalized_pubsub_redis_url == "redis://pubsub-host:6381/5" - - -def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.MonkeyPatch): - os.environ.clear() - - monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") - monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") - monkeypatch.setenv("DB_USERNAME", "postgres") - monkeypatch.setenv("DB_PASSWORD", "postgres") - monkeypatch.setenv("DB_HOST", 
"localhost") - monkeypatch.setenv("DB_PORT", "5432") - monkeypatch.setenv("DB_DATABASE", "dify") - monkeypatch.setenv("REDIS_HOST", "") - - with pytest.raises(ValueError, match="PUBSUB_REDIS_URL must be set"): - _ = DifyConfig().normalized_pubsub_redis_url - - @pytest.mark.parametrize( ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), [ diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index da957d3a81..e3c1a617f7 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -51,8 +51,6 @@ def _patch_redis_clients_on_loaded_modules(): continue if hasattr(module, "redis_client"): module.redis_client = redis_mock - if hasattr(module, "pubsub_redis_client"): - module.pubsub_redis_client = redis_mock @pytest.fixture @@ -70,10 +68,7 @@ def _provide_app_context(app: Flask): def _patch_redis_clients(): """Patch redis_client to MagicMock only for unit test executions.""" - with ( - patch.object(ext_redis, "redis_client", redis_mock), - patch.object(ext_redis, "pubsub_redis_client", redis_mock), - ): + with patch.object(ext_redis, "redis_client", redis_mock): _patch_redis_clients_on_loaded_modules() yield diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py index 2ac3dc037d..c557605916 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -16,9 +16,11 @@ if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] -@pytest.fixture(scope="module") -def app_module(): +def _load_app_module(): module_name = "controllers.console.app.app" + if module_name in sys.modules: + return sys.modules[module_name] + root = Path(__file__).resolve().parents[5] module_path = root / "controllers" / "console" / "app" / "app.py" @@ -57,12 +59,8 @@ def app_module(): stub_namespace = _StubNamespace() - original_modules: dict[str, ModuleType | None] = { - "controllers.console": sys.modules.get("controllers.console"), - "controllers.console.app": sys.modules.get("controllers.console.app"), - "controllers.common.schema": sys.modules.get("controllers.common.schema"), - module_name: sys.modules.get(module_name), - } + original_console = sys.modules.get("controllers.console") + original_app_pkg = sys.modules.get("controllers.console.app") stubbed_modules: list[tuple[str, ModuleType | None]] = [] console_module = ModuleType("controllers.console") @@ -107,35 +105,35 @@ def app_module(): module = util.module_from_spec(spec) sys.modules[module_name] = module - assert spec.loader is not None - spec.loader.exec_module(module) - try: - yield module + assert spec.loader is not None + spec.loader.exec_module(module) finally: for name, original in reversed(stubbed_modules): if original is not None: sys.modules[name] = original else: sys.modules.pop(name, None) - for name, original in original_modules.items(): - if original is not None: - sys.modules[name] = original - else: - sys.modules.pop(name, None) + if original_console is not None: + sys.modules["controllers.console"] = original_console + else: + sys.modules.pop("controllers.console", None) + if original_app_pkg is not None: + sys.modules["controllers.console.app"] = original_app_pkg + else: + sys.modules.pop("controllers.console.app", None) + + return module -@pytest.fixture(scope="module") -def 
app_models(app_module):
-    return SimpleNamespace(
-        AppDetailWithSite=app_module.AppDetailWithSite,
-        AppPagination=app_module.AppPagination,
-        AppPartial=app_module.AppPartial,
-    )
+_app_module = _load_app_module()
+AppDetailWithSite = _app_module.AppDetailWithSite
+AppPagination = _app_module.AppPagination
+AppPartial = _app_module.AppPartial
 
 
 @pytest.fixture(autouse=True)
-def patch_signed_url(monkeypatch, app_module):
+def patch_signed_url(monkeypatch):
     """Ensure icon URL generation uses a deterministic helper for tests."""
 
     def _fake_signed_url(key: str | None) -> str | None:
@@ -143,7 +141,7 @@ def patch_signed_url(monkeypatch, app_module):
             return None
         return f"signed:{key}"
 
-    monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
+    monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
 
 
 def _ts(hour: int = 12) -> datetime:
@@ -171,8 +169,7 @@ def _dummy_workflow():
     )
 
 
-def test_app_partial_serialization_uses_aliases(app_models):
-    AppPartial = app_models.AppPartial
+def test_app_partial_serialization_uses_aliases():
     created_at = _ts()
     app_obj = SimpleNamespace(
         id="app-1",
@@ -207,8 +204,7 @@ def test_app_partial_serialization_uses_aliases(app_models):
     assert serialized["tags"][0]["name"] == "Utilities"
 
 
-def test_app_detail_with_site_includes_nested_serialization(app_models):
-    AppDetailWithSite = app_models.AppDetailWithSite
+def test_app_detail_with_site_includes_nested_serialization():
     timestamp = _ts(14)
     site = SimpleNamespace(
         code="site-code",
@@ -257,8 +253,7 @@ def test_app_detail_with_site_includes_nested_serialization(app_models):
     assert serialized["site"]["created_at"] == int(timestamp.timestamp())
 
 
-def test_app_pagination_aliases_per_page_and_has_next(app_models):
-    AppPagination = app_models.AppPagination
+def test_app_pagination_aliases_per_page_and_has_next():
     item_one = SimpleNamespace(
         id="app-10",
         name="Paginated One",
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py
deleted file mode 100644
index 86a3b2bd93..0000000000
--- a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py
+++ /dev/null
@@ -1,229 +0,0 @@
-from __future__ import annotations
-
-from dataclasses import dataclass
-from types import SimpleNamespace
-from unittest.mock import MagicMock
-
-import pytest
-from flask import Flask
-from pydantic import ValidationError
-
-from controllers.console import wraps as console_wraps
-from controllers.console.app import workflow as workflow_module
-from controllers.console.app import wraps as app_wraps
-from libs import login as login_lib
-from models.account import Account, AccountStatus, TenantAccountRole
-from models.model import AppMode
-
-
-def _make_account() -> Account:
-    account = Account(name="tester", email="tester@example.com")
-    account.status = AccountStatus.ACTIVE
-    account.role = TenantAccountRole.OWNER
-    account.id = "account-123"  # type: ignore[assignment]
-    account._current_tenant = SimpleNamespace(id="tenant-123")  # type: ignore[attr-defined]
-    account._get_current_object = lambda: account  # type: ignore[attr-defined]
-    return account
-
-
-def _make_app(mode: AppMode) -> SimpleNamespace:
-    return SimpleNamespace(id="app-123", tenant_id="tenant-123", mode=mode.value)
-
-
-def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None:
-    # Skip setup and auth guardrails
-    monkeypatch.setattr("configs.dify_config.EDITION", "CLOUD")
-    monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
-    monkeypatch.setattr(login_lib, "current_user", account)
-    monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
-    monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
-    monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
-    monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
-    monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
-    monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
-    monkeypatch.delenv("INIT_PASSWORD", raising=False)
-
-    # Avoid hitting the database when resolving the app model
-    monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
-
-
-@dataclass
-class PreviewCase:
-    resource_cls: type
-    path: str
-    mode: AppMode
-
-
-@pytest.mark.parametrize(
-    "case",
-    [
-        PreviewCase(
-            resource_cls=workflow_module.AdvancedChatDraftHumanInputFormPreviewApi,
-            path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-42/form/preview",
-            mode=AppMode.ADVANCED_CHAT,
-        ),
-        PreviewCase(
-            resource_cls=workflow_module.WorkflowDraftHumanInputFormPreviewApi,
-            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-42/form/preview",
-            mode=AppMode.WORKFLOW,
-        ),
-    ],
-)
-def test_human_input_preview_delegates_to_service(
-    app: Flask, monkeypatch: pytest.MonkeyPatch, case: PreviewCase
-) -> None:
-    account = _make_account()
-    app_model = _make_app(case.mode)
-    _patch_console_guards(monkeypatch, account, app_model)
-
-    preview_payload = {
-        "form_id": "node-42",
-        "form_content": "example",
-        "inputs": [{"name": "topic"}],
-        "actions": [{"id": "continue"}],
-    }
-    service_instance = MagicMock()
-    service_instance.get_human_input_form_preview.return_value = preview_payload
-    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
-
-    with app.test_request_context(case.path, method="POST", json={"inputs": {"topic": "tech"}}):
-        response = case.resource_cls().post(app_id=app_model.id, node_id="node-42")
-
-    assert response == preview_payload
-    service_instance.get_human_input_form_preview.assert_called_once_with(
-        app_model=app_model,
-        account=account,
-        node_id="node-42",
-        inputs={"topic": "tech"},
-    )
-
-
-@dataclass
-class SubmitCase:
-    resource_cls: type
-    path: str
-    mode: AppMode
-
-
-@pytest.mark.parametrize(
-    "case",
-    [
-        SubmitCase(
-            resource_cls=workflow_module.AdvancedChatDraftHumanInputFormRunApi,
-            path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-99/form/run",
-            mode=AppMode.ADVANCED_CHAT,
-        ),
-        SubmitCase(
-            resource_cls=workflow_module.WorkflowDraftHumanInputFormRunApi,
-            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-99/form/run",
-            mode=AppMode.WORKFLOW,
-        ),
-    ],
-)
-def test_human_input_submit_forwards_payload(app: Flask, monkeypatch: pytest.MonkeyPatch, case: SubmitCase) -> None:
-    account = _make_account()
-    app_model = _make_app(case.mode)
-    _patch_console_guards(monkeypatch, account, app_model)
-
-    result_payload = {"node_id": "node-99", "outputs": {"__rendered_content": "done"}, "action": "approve"}
-    service_instance = MagicMock()
-    service_instance.submit_human_input_form_preview.return_value = result_payload
-    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
-
-    with app.test_request_context(
-        case.path,
-        method="POST",
-        json={"form_inputs": {"answer": "42"}, "inputs": {"#node-1.result#": "LLM output"}, "action": "approve"},
-    ):
-        response = case.resource_cls().post(app_id=app_model.id, node_id="node-99")
-
-    assert response == result_payload
-    service_instance.submit_human_input_form_preview.assert_called_once_with(
-        app_model=app_model,
-        account=account,
-        node_id="node-99",
-        form_inputs={"answer": "42"},
-        inputs={"#node-1.result#": "LLM output"},
-        action="approve",
-    )
-
-
-@dataclass
-class DeliveryTestCase:
-    resource_cls: type
-    path: str
-    mode: AppMode
-
-
-@pytest.mark.parametrize(
-    "case",
-    [
-        DeliveryTestCase(
-            resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi,
-            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test",
-            mode=AppMode.ADVANCED_CHAT,
-        ),
-        DeliveryTestCase(
-            resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi,
-            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test",
-            mode=AppMode.WORKFLOW,
-        ),
-    ],
-)
-def test_human_input_delivery_test_calls_service(
-    app: Flask, monkeypatch: pytest.MonkeyPatch, case: DeliveryTestCase
-) -> None:
-    account = _make_account()
-    app_model = _make_app(case.mode)
-    _patch_console_guards(monkeypatch, account, app_model)
-
-    service_instance = MagicMock()
-    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
-
-    with app.test_request_context(
-        case.path,
-        method="POST",
-        json={"delivery_method_id": "delivery-123"},
-    ):
-        response = case.resource_cls().post(app_id=app_model.id, node_id="node-7")
-
-    assert response == {}
-    service_instance.test_human_input_delivery.assert_called_once_with(
-        app_model=app_model,
-        account=account,
-        node_id="node-7",
-        delivery_method_id="delivery-123",
-        inputs={},
-    )
-
-
-def test_human_input_delivery_test_maps_validation_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
-    account = _make_account()
-    app_model = _make_app(AppMode.ADVANCED_CHAT)
-    _patch_console_guards(monkeypatch, account, app_model)
-
-    service_instance = MagicMock()
-    service_instance.test_human_input_delivery.side_effect = ValueError("bad delivery method")
-    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
-
-    with app.test_request_context(
-        "/console/api/apps/app-123/workflows/draft/human-input/nodes/node-1/delivery-test",
-        method="POST",
-        json={"delivery_method_id": "bad"},
-    ):
-        with pytest.raises(ValueError):
-            workflow_module.WorkflowDraftHumanInputDeliveryTestApi().post(app_id=app_model.id, node_id="node-1")
-
-
-def test_human_input_preview_rejects_non_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
-    account = _make_account()
-    app_model = _make_app(AppMode.ADVANCED_CHAT)
-    _patch_console_guards(monkeypatch, account, app_model)
-
-    with app.test_request_context(
-        "/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-1/form/preview",
-        method="POST",
-        json={"inputs": ["not-a-dict"]},
-    ):
-        with pytest.raises(ValidationError):
-            workflow_module.AdvancedChatDraftHumanInputFormPreviewApi().post(app_id=app_model.id, node_id="node-1")
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py
deleted file mode 100644
index 34d6a2232c..0000000000
--- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py
+++ /dev/null
@@ -1,91 +0,0 @@
-from __future__ import annotations
-
-from datetime import datetime
-from types import SimpleNamespace
-from unittest.mock import Mock
-
-import pytest
-from flask import Flask
-
-from controllers.console import wraps as console_wraps
-from controllers.console.app import workflow_run as workflow_run_module
-from core.workflow.entities.pause_reason import HumanInputRequired
-from core.workflow.enums import WorkflowExecutionStatus
-from core.workflow.nodes.human_input.entities import FormInput, UserAction
-from core.workflow.nodes.human_input.enums import FormInputType
-from libs import login as login_lib
-from models.account import Account, AccountStatus, TenantAccountRole
-from models.workflow import WorkflowRun
-
-
-def _make_account() -> Account:
-    account = Account(name="tester", email="tester@example.com")
-    account.status = AccountStatus.ACTIVE
-    account.role = TenantAccountRole.OWNER
-    account.id = "account-123"  # type: ignore[assignment]
-    account._current_tenant = SimpleNamespace(id="tenant-123")  # type: ignore[attr-defined]
-    account._get_current_object = lambda: account  # type: ignore[attr-defined]
-    return account
-
-
-def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account) -> None:
-    monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
-    monkeypatch.setattr(login_lib, "current_user", account)
-    monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
-    monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
-    monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
-    monkeypatch.setattr(workflow_run_module, "current_user", account)
-    monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
-
-
-class _PauseEntity:
-    def __init__(self, paused_at: datetime, reasons: list[HumanInputRequired]):
-        self.paused_at = paused_at
-        self._reasons = reasons
-
-    def get_pause_reasons(self):
-        return self._reasons
-
-
-def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
-    account = _make_account()
-    _patch_console_guards(monkeypatch, account)
-    monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")
-
-    workflow_run = Mock(spec=WorkflowRun)
-    workflow_run.status = WorkflowExecutionStatus.PAUSED
-    workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0)
-    fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
-    monkeypatch.setattr(workflow_run_module, "db", fake_db)
-
-    reason = HumanInputRequired(
-        form_id="form-1",
-        form_content="content",
-        inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
-        actions=[UserAction(id="approve", title="Approve")],
-        node_id="node-1",
-        node_title="Ask Name",
-        form_token="backstage-token",
-    )
-    pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason])
-
-    repo = Mock()
-    repo.get_workflow_pause.return_value = pause_entity
-    monkeypatch.setattr(
-        workflow_run_module.DifyAPIRepositoryFactory,
-        "create_api_workflow_run_repository",
-        lambda *_, **__: repo,
-    )
-
-    with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
-        response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")
-
-    assert status == 200
-    assert response["paused_at"] == "2024-01-01T12:00:00Z"
-    assert response["paused_nodes"][0]["node_id"] == "node-1"
-    assert response["paused_nodes"][0]["pause_type"]["type"] == "human_input"
-    assert (
-        response["paused_nodes"][0]["pause_type"]["backstage_input_url"]
-        == "https://web.example.com/form/backstage-token"
-    )
-    assert "pending_human_inputs" not in response
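The URL assertion in the deleted test above implies a simple composition rule: the backstage input link is the configured APP_WEB_URL joined with /form/ and the form token. A small sketch of that composition under the same fixture values; the helper name is hypothetical.

def backstage_input_url(app_web_url: str, form_token: str) -> str:
    # https://web.example.com + /form/ + backstage-token, per the deleted assertion.
    return f"{app_web_url.rstrip('/')}/form/{form_token}"


assert backstage_input_url("https://web.example.com", "backstage-token") == (
    "https://web.example.com/form/backstage-token"
)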
app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): - response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") - - assert status == 200 - assert response["paused_at"] == "2024-01-01T12:00:00Z" - assert response["paused_nodes"][0]["node_id"] == "node-1" - assert response["paused_nodes"][0]["pause_type"]["type"] == "human_input" - assert ( - response["paused_nodes"][0]["pause_type"]["backstage_input_url"] - == "https://web.example.com/form/backstage-token" - ) - assert "pending_human_inputs" not in response diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py deleted file mode 100644 index fcaa61a871..0000000000 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ /dev/null @@ -1,25 +0,0 @@ -from types import SimpleNamespace - -from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField -from core.workflow.enums import WorkflowExecutionStatus - - -def test_workflow_run_status_field_with_enum() -> None: - field = WorkflowRunStatusField() - obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED) - - assert field.output("status", obj) == "paused" - - -def test_workflow_run_outputs_field_paused_returns_empty() -> None: - field = WorkflowRunOutputsField() - obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED, outputs_dict={"foo": "bar"}) - - assert field.output("outputs", obj) == {} - - -def test_workflow_run_outputs_field_running_returns_outputs() -> None: - field = WorkflowRunOutputsField() - obj = SimpleNamespace(status=WorkflowExecutionStatus.RUNNING, outputs_dict={"foo": "bar"}) - - assert field.output("outputs", obj) == {"foo": "bar"} diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py deleted file mode 100644 index 4fb735b033..0000000000 --- a/api/tests/unit_tests/controllers/web/test_human_input_form.py +++ /dev/null @@ -1,456 +0,0 @@ -"""Unit tests for controllers.web.human_input_form endpoints.""" - -from __future__ import annotations - -import json -from datetime import UTC, datetime -from types import SimpleNamespace -from typing import Any -from unittest.mock import MagicMock - -import pytest -from flask import Flask -from werkzeug.exceptions import Forbidden - -import controllers.web.human_input_form as human_input_module -import controllers.web.site as site_module -from controllers.web.error import WebFormRateLimitExceededError -from models.human_input import RecipientType -from services.human_input_service import FormExpiredError - -HumanInputFormApi = human_input_module.HumanInputFormApi -TenantStatus = human_input_module.TenantStatus - - -@pytest.fixture -def app() -> Flask: - """Configure a minimal Flask app for request contexts.""" - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - -class _FakeSession: - """Simple stand-in for db.session that returns pre-seeded objects.""" - - def __init__(self, mapping: dict[str, Any]): - self._mapping = mapping - self._model_name: str | None = None - - def query(self, model): - self._model_name = model.__name__ - return self - - def where(self, *args, **kwargs): - return self - - def first(self): - assert self._model_name is not None - return self._mapping.get(self._model_name) - - -class _FakeDB: - """Minimal db stub exposing engine and session.""" - - def __init__(self, 
session: _FakeSession): - self.session = session - self.engine = object() - - -def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask): - """GET returns form definition merged with site payload.""" - - expiration_time = datetime(2099, 1, 1, tzinfo=UTC) - - class _FakeDefinition: - def model_dump(self): - return { - "form_content": "Raw content", - "rendered_content": "Rendered {{#$output.name#}}", - "inputs": [{"type": "text", "output_variable_name": "name", "default": None}], - "default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}}, - "user_actions": [{"id": "approve", "title": "Approve", "button_style": "default"}], - } - - class _FakeForm: - def __init__(self, expiration: datetime): - self.workflow_run_id = "workflow-1" - self.app_id = "app-1" - self.tenant_id = "tenant-1" - self.expiration_time = expiration - self.recipient_type = RecipientType.BACKSTAGE - - def get_definition(self): - return _FakeDefinition() - - form = _FakeForm(expiration_time) - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = False - monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - - tenant = SimpleNamespace( - id="tenant-1", - status=TenantStatus.NORMAL, - plan="basic", - custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False}, - ) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) - workflow_run = SimpleNamespace(app_id="app-1") - site_model = SimpleNamespace( - title="My Site", - icon_type="emoji", - icon="robot", - icon_background="#fff", - description="desc", - default_language="en", - chat_color_theme="light", - chat_color_theme_inverted=False, - copyright=None, - privacy_policy=None, - custom_disclaimer=None, - prompt_public=False, - show_workflow_steps=True, - use_icon_as_answer_icon=False, - ) - - # Patch service to return fake form. - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = form - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - - # Patch db session. 
- db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model})) - monkeypatch.setattr(human_input_module, "db", db_stub) - - monkeypatch.setattr( - site_module.FeatureService, - "get_features", - lambda tenant_id: SimpleNamespace(can_replace_logo=True), - ) - - with app.test_request_context("/api/form/human_input/token-1", method="GET"): - response = HumanInputFormApi().get("token-1") - - body = json.loads(response.get_data(as_text=True)) - assert set(body.keys()) == { - "site", - "form_content", - "inputs", - "resolved_default_values", - "user_actions", - "expiration_time", - } - assert body["form_content"] == "Rendered {{#$output.name#}}" - assert body["inputs"] == [{"type": "text", "output_variable_name": "name", "default": None}] - assert body["resolved_default_values"] == {"name": "Alice", "age": "30", "meta": '{"k": "v"}'} - assert body["user_actions"] == [{"id": "approve", "title": "Approve", "button_style": "default"}] - assert body["expiration_time"] == int(expiration_time.timestamp()) - assert body["site"] == { - "app_id": "app-1", - "end_user_id": None, - "enable_site": True, - "site": { - "title": "My Site", - "chat_color_theme": "light", - "chat_color_theme_inverted": False, - "icon_type": "emoji", - "icon": "robot", - "icon_background": "#fff", - "icon_url": None, - "description": "desc", - "copyright": None, - "privacy_policy": None, - "custom_disclaimer": None, - "default_language": "en", - "prompt_public": False, - "show_workflow_steps": True, - "use_icon_as_answer_icon": False, - }, - "model_config": None, - "plan": "basic", - "can_replace_logo": True, - "custom_config": { - "remove_webapp_brand": True, - "replace_webapp_logo": None, - }, - } - service_mock.get_form_by_token.assert_called_once_with("token-1") - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") - - -def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask): - """GET returns form payload for backstage token.""" - - expiration_time = datetime(2099, 1, 2, tzinfo=UTC) - - class _FakeDefinition: - def model_dump(self): - return { - "form_content": "Raw content", - "rendered_content": "Rendered", - "inputs": [], - "default_values": {}, - "user_actions": [], - } - - class _FakeForm: - def __init__(self, expiration: datetime): - self.workflow_run_id = "workflow-1" - self.app_id = "app-1" - self.tenant_id = "tenant-1" - self.expiration_time = expiration - - def get_definition(self): - return _FakeDefinition() - - form = _FakeForm(expiration_time) - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = False - monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - tenant = SimpleNamespace( - id="tenant-1", - status=TenantStatus.NORMAL, - plan="basic", - custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False}, - ) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) - workflow_run = SimpleNamespace(app_id="app-1") - site_model = SimpleNamespace( - title="My Site", - icon_type="emoji", - icon="robot", - icon_background="#fff", - description="desc", - default_language="en", - chat_color_theme="light", - chat_color_theme_inverted=False, - copyright=None, - privacy_policy=None, - custom_disclaimer=None, - prompt_public=False, - show_workflow_steps=True, - 
use_icon_as_answer_icon=False, - ) - - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = form - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - - db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model})) - monkeypatch.setattr(human_input_module, "db", db_stub) - - monkeypatch.setattr( - site_module.FeatureService, - "get_features", - lambda tenant_id: SimpleNamespace(can_replace_logo=True), - ) - - with app.test_request_context("/api/form/human_input/token-1", method="GET"): - response = HumanInputFormApi().get("token-1") - - body = json.loads(response.get_data(as_text=True)) - assert set(body.keys()) == { - "site", - "form_content", - "inputs", - "resolved_default_values", - "user_actions", - "expiration_time", - } - assert body["form_content"] == "Rendered" - assert body["inputs"] == [] - assert body["resolved_default_values"] == {} - assert body["user_actions"] == [] - assert body["expiration_time"] == int(expiration_time.timestamp()) - assert body["site"] == { - "app_id": "app-1", - "end_user_id": None, - "enable_site": True, - "site": { - "title": "My Site", - "chat_color_theme": "light", - "chat_color_theme_inverted": False, - "icon_type": "emoji", - "icon": "robot", - "icon_background": "#fff", - "icon_url": None, - "description": "desc", - "copyright": None, - "privacy_policy": None, - "custom_disclaimer": None, - "default_language": "en", - "prompt_public": False, - "show_workflow_steps": True, - "use_icon_as_answer_icon": False, - }, - "model_config": None, - "plan": "basic", - "can_replace_logo": True, - "custom_config": { - "remove_webapp_brand": True, - "replace_webapp_logo": None, - }, - } - service_mock.get_form_by_token.assert_called_once_with("token-1") - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") - - -def test_get_form_raises_forbidden_when_site_missing(monkeypatch: pytest.MonkeyPatch, app: Flask): - """GET raises Forbidden if site cannot be resolved.""" - - expiration_time = datetime(2099, 1, 3, tzinfo=UTC) - - class _FakeDefinition: - def model_dump(self): - return { - "form_content": "Raw content", - "rendered_content": "Rendered", - "inputs": [], - "default_values": {}, - "user_actions": [], - } - - class _FakeForm: - def __init__(self, expiration: datetime): - self.workflow_run_id = "workflow-1" - self.app_id = "app-1" - self.tenant_id = "tenant-1" - self.expiration_time = expiration - - def get_definition(self): - return _FakeDefinition() - - form = _FakeForm(expiration_time) - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = False - monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - tenant = SimpleNamespace(status=TenantStatus.NORMAL) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) - workflow_run = SimpleNamespace(app_id="app-1") - - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = form - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - - db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": None})) - monkeypatch.setattr(human_input_module, "db", db_stub) - - with app.test_request_context("/api/form/human_input/token-1", method="GET"): - with pytest.raises(Forbidden): - 
HumanInputFormApi().get("token-1") - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") - - -def test_submit_form_accepts_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask): - """POST forwards backstage submissions to the service.""" - - class _FakeForm: - recipient_type = RecipientType.BACKSTAGE - - form = _FakeForm() - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = False - monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = form - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) - - with app.test_request_context( - "/api/form/human_input/token-1", - method="POST", - json={"inputs": {"content": "ok"}, "action": "approve"}, - ): - response, status = HumanInputFormApi().post("token-1") - - assert status == 200 - assert response == {} - service_mock.submit_form_by_token.assert_called_once_with( - recipient_type=RecipientType.BACKSTAGE, - form_token="token-1", - selected_action_id="approve", - form_data={"content": "ok"}, - submission_end_user_id=None, - ) - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") - - -def test_submit_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask): - """POST rejects submissions when rate limit is exceeded.""" - - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = True - monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = None - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) - - with app.test_request_context( - "/api/form/human_input/token-1", - method="POST", - json={"inputs": {"content": "ok"}, "action": "approve"}, - ): - with pytest.raises(WebFormRateLimitExceededError): - HumanInputFormApi().post("token-1") - - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_not_called() - service_mock.get_form_by_token.assert_not_called() - - -def test_get_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask): - """GET rejects requests when rate limit is exceeded.""" - - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = True - monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = None - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) - - with app.test_request_context("/api/form/human_input/token-1", method="GET"): - with pytest.raises(WebFormRateLimitExceededError): - HumanInputFormApi().get("token-1") - - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - 
limiter_mock.increment_rate_limit.assert_not_called() - service_mock.get_form_by_token.assert_not_called() - - -def test_get_form_raises_expired(monkeypatch: pytest.MonkeyPatch, app: Flask): - class _FakeForm: - pass - - form = _FakeForm() - limiter_mock = MagicMock() - limiter_mock.is_rate_limited.return_value = False - monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) - monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") - service_mock = MagicMock() - service_mock.get_form_by_token.return_value = form - service_mock.ensure_form_active.side_effect = FormExpiredError("form-id") - monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) - monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) - - with app.test_request_context("/api/form/human_input/token-1", method="GET"): - with pytest.raises(FormExpiredError): - HumanInputFormApi().get("token-1") - - service_mock.ensure_form_active.assert_called_once_with(form) - limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") - limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") diff --git a/api/tests/unit_tests/controllers/web/test_message_list.py b/api/tests/unit_tests/controllers/web/test_message_list.py index 1c096bfbcf..2835f7ffbf 100644 --- a/api/tests/unit_tests/controllers/web/test_message_list.py +++ b/api/tests/unit_tests/controllers/web/test_message_list.py @@ -3,7 +3,6 @@ from __future__ import annotations import builtins -import uuid from datetime import datetime from types import ModuleType, SimpleNamespace from unittest.mock import patch @@ -13,8 +12,6 @@ import pytest from flask import Flask from flask.views import MethodView -from core.entities.execution_extra_content import HumanInputContent - # Ensure flask_restx.api finds MethodView during import. 
if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] @@ -140,12 +137,6 @@ def test_message_list_mapping(app: Flask) -> None: status="success", error=None, message_metadata_dict={"meta": "value"}, - extra_contents=[ - HumanInputContent( - workflow_run_id=str(uuid.uuid4()), - submitted=True, - ) - ], ) pagination = SimpleNamespace(limit=20, has_more=False, data=[message]) @@ -178,8 +169,6 @@ def test_message_list_mapping(app: Flask) -> None: assert item["agent_thoughts"][0]["chain_id"] == "chain-1" assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp()) - assert item["extra_contents"][0]["workflow_run_id"] == message.extra_contents[0].workflow_run_id - assert item["extra_contents"][0]["submitted"] == message.extra_contents[0].submitted assert item["message_files"][0]["id"] == "file-dict" assert item["message_files"][1]["id"] == "file-obj" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py deleted file mode 100644 index a94b5445f7..0000000000 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py +++ /dev/null @@ -1,187 +0,0 @@ -from __future__ import annotations - -from contextlib import contextmanager -from datetime import datetime -from types import SimpleNamespace -from unittest import mock - -import pytest - -from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent -from core.workflow.entities.pause_reason import HumanInputRequired -from models.enums import MessageStatus -from models.execution_extra_content import HumanInputContent -from models.model import EndUser - - -def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline: - pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline.__new__( - pipeline_module.AdvancedChatAppGenerateTaskPipeline - ) - pipeline._workflow_run_id = "run-1" - pipeline._message_id = "message-1" - pipeline._workflow_tenant_id = "tenant-1" - return pipeline - - -def test_persist_human_input_extra_content_adds_record(monkeypatch: pytest.MonkeyPatch) -> None: - pipeline = _build_pipeline() - monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1") - - captured_session: dict[str, mock.Mock] = {} - - @contextmanager - def fake_session(): - session = mock.Mock() - session.scalar.return_value = None - captured_session["session"] = session - yield session - - pipeline._database_session = fake_session # type: ignore[method-assign] - - pipeline._persist_human_input_extra_content(node_id="node-1") - - session = captured_session["session"] - session.add.assert_called_once() - content = session.add.call_args.args[0] - assert isinstance(content, HumanInputContent) - assert content.workflow_run_id == "run-1" - assert content.message_id == "message-1" - assert content.form_id == "form-1" - - -def test_persist_human_input_extra_content_skips_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: - pipeline = _build_pipeline() - monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: None) - - called = {"value": False} - - @contextmanager - def fake_session(): - called["value"] = True - session = mock.Mock() - yield session - - 
pipeline._database_session = fake_session # type: ignore[method-assign] - - pipeline._persist_human_input_extra_content(node_id="node-1") - - assert called["value"] is False - - -def test_persist_human_input_extra_content_skips_when_existing(monkeypatch: pytest.MonkeyPatch) -> None: - pipeline = _build_pipeline() - monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1") - - captured_session: dict[str, mock.Mock] = {} - - @contextmanager - def fake_session(): - session = mock.Mock() - session.scalar.return_value = HumanInputContent( - workflow_run_id="run-1", - message_id="message-1", - form_id="form-1", - ) - captured_session["session"] = session - yield session - - pipeline._database_session = fake_session # type: ignore[method-assign] - - pipeline._persist_human_input_extra_content(node_id="node-1") - - session = captured_session["session"] - session.add.assert_not_called() - - -def test_handle_workflow_paused_event_persists_human_input_extra_content() -> None: - pipeline = _build_pipeline() - pipeline._application_generate_entity = SimpleNamespace(task_id="task-1") - pipeline._workflow_response_converter = mock.Mock() - pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = [] - pipeline._ensure_graph_runtime_initialized = mock.Mock( - return_value=SimpleNamespace( - total_tokens=0, - node_run_steps=0, - ), - ) - pipeline._save_message = mock.Mock() - message = SimpleNamespace(status=MessageStatus.NORMAL) - pipeline._get_message = mock.Mock(return_value=message) - pipeline._persist_human_input_extra_content = mock.Mock() - pipeline._base_task_pipeline = mock.Mock() - pipeline._base_task_pipeline.queue_manager = mock.Mock() - pipeline._message_saved_on_pause = False - - @contextmanager - def fake_session(): - session = mock.Mock() - yield session - - pipeline._database_session = fake_session # type: ignore[method-assign] - - reason = HumanInputRequired( - form_id="form-1", - form_content="content", - inputs=[], - actions=[], - node_id="node-1", - node_title="Approval", - form_token="token-1", - resolved_default_values={}, - ) - event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"]) - - list(pipeline._handle_workflow_paused_event(event)) - - pipeline._persist_human_input_extra_content.assert_called_once_with(form_id="form-1", node_id="node-1") - assert message.status == MessageStatus.PAUSED - - -def test_resume_appends_chunks_to_paused_answer() -> None: - app_config = SimpleNamespace(app_id="app-1", tenant_id="tenant-1", sensitive_word_avoidance=None) - application_generate_entity = SimpleNamespace( - app_config=app_config, - files=[], - workflow_run_id="run-1", - query="hello", - invoke_from=InvokeFrom.WEB_APP, - inputs={}, - task_id="task-1", - ) - queue_manager = SimpleNamespace(graph_runtime_state=None) - conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat") - message = SimpleNamespace( - id="message-1", - created_at=datetime(2024, 1, 1), - query="hello", - answer="before", - status=MessageStatus.PAUSED, - ) - user = EndUser() - user.id = "user-1" - user.session_id = "session-1" - workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={}) - - pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=True, - dialogue_count=1, - 
draft_var_saver_factory=SimpleNamespace(), - ) - - pipeline._get_message = mock.Mock(return_value=message) - pipeline._recorded_files = [] - - list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after"))) - pipeline._save_message(session=mock.Mock()) - - assert message.answer == "beforeafter" - assert message.status == MessageStatus.NORMAL diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py deleted file mode 100644 index 1c36b4d12b..0000000000 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ /dev/null @@ -1,87 +0,0 @@ -from datetime import UTC, datetime -from types import SimpleNamespace - -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable - - -def _build_converter(): - system_variables = SystemVariable( - files=[], - user_id="user-1", - app_id="app-1", - workflow_id="wf-1", - workflow_execution_id="run-1", - ) - runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) - app_entity = SimpleNamespace( - task_id="task-1", - app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"), - invoke_from=InvokeFrom.EXPLORE, - files=[], - inputs={}, - workflow_execution_id="run-1", - call_depth=0, - ) - account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com") - return WorkflowResponseConverter( - application_generate_entity=app_entity, - user=account, - system_variables=system_variables, - ) - - -def test_human_input_form_filled_stream_response_contains_rendered_content(): - converter = _build_converter() - converter.workflow_start_to_stream_response( - task_id="task-1", - workflow_run_id="run-1", - workflow_id="wf-1", - reason=WorkflowStartReason.INITIAL, - ) - - queue_event = QueueHumanInputFormFilledEvent( - node_execution_id="exec-1", - node_id="node-1", - node_type="human-input", - node_title="Human Input", - rendered_content="# Title\nvalue", - action_id="Approve", - action_text="Approve", - ) - - resp = converter.human_input_form_filled_to_stream_response(event=queue_event, task_id="task-1") - - assert resp.workflow_run_id == "run-1" - assert resp.data.node_id == "node-1" - assert resp.data.node_title == "Human Input" - assert resp.data.rendered_content.startswith("# Title") - assert resp.data.action_id == "Approve" - - -def test_human_input_form_timeout_stream_response_contains_timeout_metadata(): - converter = _build_converter() - converter.workflow_start_to_stream_response( - task_id="task-1", - workflow_run_id="run-1", - workflow_id="wf-1", - reason=WorkflowStartReason.INITIAL, - ) - - queue_event = QueueHumanInputFormTimeoutEvent( - node_id="node-1", - node_type="human-input", - node_title="Human Input", - expiration_time=datetime(2025, 1, 1, tzinfo=UTC), - ) - - resp = converter.human_input_form_timeout_to_stream_response(event=queue_event, task_id="task-1") - - assert resp.workflow_run_id == "run-1" - assert resp.data.node_id == "node-1" - assert resp.data.node_title == "Human Input" - assert resp.data.expiration_time == 
1735689600 diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py deleted file mode 100644 index 0a9794e41c..0000000000 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ /dev/null @@ -1,56 +0,0 @@ -from types import SimpleNamespace - -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable - - -def _build_converter() -> WorkflowResponseConverter: - """Construct a minimal WorkflowResponseConverter for testing.""" - system_variables = SystemVariable( - files=[], - user_id="user-1", - app_id="app-1", - workflow_id="wf-1", - workflow_execution_id="run-1", - ) - runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) - app_entity = SimpleNamespace( - task_id="task-1", - app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"), - invoke_from=InvokeFrom.EXPLORE, - files=[], - inputs={}, - workflow_execution_id="run-1", - call_depth=0, - ) - account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com") - return WorkflowResponseConverter( - application_generate_entity=app_entity, - user=account, - system_variables=system_variables, - ) - - -def test_workflow_start_stream_response_carries_resumption_reason(): - converter = _build_converter() - resp = converter.workflow_start_to_stream_response( - task_id="task-1", - workflow_run_id="run-1", - workflow_id="wf-1", - reason=WorkflowStartReason.RESUMPTION, - ) - assert resp.data.reason is WorkflowStartReason.RESUMPTION - - -def test_workflow_start_stream_response_carries_initial_reason(): - converter = _build_converter() - resp = converter.workflow_start_to_stream_response( - task_id="task-1", - workflow_run_id="run-1", - workflow_id="wf-1", - reason=WorkflowStartReason.INITIAL, - ) - assert resp.data.reason is WorkflowStartReason.INITIAL diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index d25bff92dc..6b40bf462b 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -23,7 +23,6 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) -from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import NodeType from core.workflow.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now @@ -125,12 +124,7 @@ class TestWorkflowResponseConverter: original_data = {"large_field": "x" * 10000, "metadata": "info"} truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"} - converter.workflow_start_to_stream_response( - task_id="bootstrap", - workflow_run_id="run-id", - workflow_id="wf-id", - reason=WorkflowStartReason.INITIAL, - ) + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") start_event = self.create_node_started_event() 
converter.workflow_node_start_to_stream_response( event=start_event, @@ -166,12 +160,7 @@ class TestWorkflowResponseConverter: original_data = {"small": "data"} - converter.workflow_start_to_stream_response( - task_id="bootstrap", - workflow_run_id="run-id", - workflow_id="wf-id", - reason=WorkflowStartReason.INITIAL, - ) + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -202,12 +191,7 @@ class TestWorkflowResponseConverter: """Test node finish response when process_data is None.""" converter = self.create_workflow_response_converter() - converter.workflow_start_to_stream_response( - task_id="bootstrap", - workflow_run_id="run-id", - workflow_id="wf-id", - reason=WorkflowStartReason.INITIAL, - ) + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -241,12 +225,7 @@ class TestWorkflowResponseConverter: original_data = {"large_field": "x" * 10000, "metadata": "info"} truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"} - converter.workflow_start_to_stream_response( - task_id="bootstrap", - workflow_run_id="run-id", - workflow_id="wf-id", - reason=WorkflowStartReason.INITIAL, - ) + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -282,12 +261,7 @@ class TestWorkflowResponseConverter: original_data = {"small": "data"} - converter.workflow_start_to_stream_response( - task_id="bootstrap", - workflow_run_id="run-id", - workflow_id="wf-id", - reason=WorkflowStartReason.INITIAL, - ) + converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -426,7 +400,6 @@ class TestWorkflowResponseConverterServiceApiTruncation: task_id="test-task-id", workflow_run_id="test-workflow-run-id", workflow_id="test-workflow-id", - reason=WorkflowStartReason.INITIAL, ) return converter diff --git a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py deleted file mode 100644 index f0d9afc0db..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py +++ /dev/null @@ -1,139 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig -from core.app.apps import message_based_app_generator -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from core.app.task_pipeline import message_cycle_manager -from core.app.task_pipeline.message_cycle_manager import MessageCycleManager -from models.model import AppMode, Conversation, Message - - -def _make_app_config() -> WorkflowUIBasedAppConfig: - return WorkflowUIBasedAppConfig( - tenant_id="tenant-id", - app_id="app-id", - app_mode=AppMode.ADVANCED_CHAT, - 
workflow_id="workflow-id", - additional_features=AppAdditionalFeatures(), - variables=[], - ) - - -def _make_generate_entity(app_config: WorkflowUIBasedAppConfig) -> AdvancedChatAppGenerateEntity: - return AdvancedChatAppGenerateEntity( - task_id="task-id", - app_config=app_config, - file_upload_config=None, - conversation_id=None, - inputs={}, - query="hello", - files=[], - parent_message_id=None, - user_id="user-id", - stream=True, - invoke_from=InvokeFrom.WEB_APP, - extras={}, - workflow_run_id="workflow-run-id", - ) - - -@pytest.fixture(autouse=True) -def _mock_db_session(monkeypatch): - session = MagicMock() - - def refresh_side_effect(obj): - if isinstance(obj, Conversation) and obj.id is None: - obj.id = "generated-conversation-id" - if isinstance(obj, Message) and obj.id is None: - obj.id = "generated-message-id" - - session.refresh.side_effect = refresh_side_effect - session.add.return_value = None - session.commit.return_value = None - - monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session)) - return session - - -def test_init_generate_records_sets_conversation_metadata(): - app_config = _make_app_config() - entity = _make_generate_entity(app_config) - - generator = AdvancedChatAppGenerator() - - conversation, _ = generator._init_generate_records(entity, conversation=None) - - assert entity.conversation_id == "generated-conversation-id" - assert conversation.id == "generated-conversation-id" - assert entity.is_new_conversation is True - - -def test_init_generate_records_marks_existing_conversation(): - app_config = _make_app_config() - entity = _make_generate_entity(app_config) - - existing_conversation = Conversation( - app_id=app_config.app_id, - app_model_config_id=None, - model_provider=None, - override_model_configs=None, - model_id=None, - mode=app_config.app_mode.value, - name="existing", - inputs={}, - introduction="", - system_instruction="", - system_instruction_tokens=0, - status="normal", - invoke_from=InvokeFrom.WEB_APP.value, - from_source="api", - from_end_user_id="user-id", - from_account_id=None, - ) - existing_conversation.id = "existing-conversation-id" - - generator = AdvancedChatAppGenerator() - - conversation, _ = generator._init_generate_records(entity, conversation=existing_conversation) - - assert entity.conversation_id == "existing-conversation-id" - assert conversation is existing_conversation - assert entity.is_new_conversation is False - - -def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch): - app_config = _make_app_config() - entity = _make_generate_entity(app_config) - entity.conversation_id = "existing-conversation-id" - entity.is_new_conversation = True - entity.extras = {"auto_generate_conversation_name": True} - - captured = {} - - class DummyThread: - def __init__(self, **kwargs): - self.kwargs = kwargs - self.started = False - - def start(self): - self.started = True - - def fake_thread(**kwargs): - thread = DummyThread(**kwargs) - captured["thread"] = thread - return thread - - monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread) - - manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock()) - thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello") - - assert thread is captured["thread"] - assert thread.started is True - assert entity.is_new_conversation is False diff --git a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py 
b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py deleted file mode 100644 index 87b8dc51e7..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py +++ /dev/null @@ -1,127 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.app.app_config.entities import ( - AppAdditionalFeatures, - EasyUIBasedAppConfig, - EasyUIBasedAppModelConfigFrom, - ModelConfigEntity, - PromptTemplateEntity, -) -from core.app.apps import message_based_app_generator -from core.app.apps.message_based_app_generator import MessageBasedAppGenerator -from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom -from models.model import AppMode, Conversation, Message - - -class DummyModelConf: - def __init__(self, provider: str = "mock-provider", model: str = "mock-model") -> None: - self.provider = provider - self.model = model - - -class DummyCompletionGenerateEntity: - __slots__ = ("app_config", "invoke_from", "user_id", "query", "inputs", "files", "model_conf") - app_config: EasyUIBasedAppConfig - invoke_from: InvokeFrom - user_id: str - query: str - inputs: dict - files: list - model_conf: DummyModelConf - - def __init__(self, app_config: EasyUIBasedAppConfig) -> None: - self.app_config = app_config - self.invoke_from = InvokeFrom.WEB_APP - self.user_id = "user-id" - self.query = "hello" - self.inputs = {} - self.files = [] - self.model_conf = DummyModelConf() - - -def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig: - return EasyUIBasedAppConfig( - tenant_id="tenant-id", - app_id="app-id", - app_mode=app_mode, - app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG, - app_model_config_id="model-config-id", - app_model_config_dict={}, - model=ModelConfigEntity(provider="mock-provider", model="mock-model", mode="chat"), - prompt_template=PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="Hello", - ), - additional_features=AppAdditionalFeatures(), - variables=[], - ) - - -def _make_chat_generate_entity(app_config: EasyUIBasedAppConfig) -> ChatAppGenerateEntity: - return ChatAppGenerateEntity.model_construct( - task_id="task-id", - app_config=app_config, - model_conf=DummyModelConf(), - file_upload_config=None, - conversation_id=None, - inputs={}, - query="hello", - files=[], - parent_message_id=None, - user_id="user-id", - stream=False, - invoke_from=InvokeFrom.WEB_APP, - extras={}, - call_depth=0, - trace_manager=None, - ) - - -@pytest.fixture(autouse=True) -def _mock_db_session(monkeypatch): - session = MagicMock() - - def refresh_side_effect(obj): - if isinstance(obj, Conversation) and obj.id is None: - obj.id = "generated-conversation-id" - if isinstance(obj, Message) and obj.id is None: - obj.id = "generated-message-id" - - session.refresh.side_effect = refresh_side_effect - session.add.return_value = None - session.commit.return_value = None - - monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session)) - return session - - -def test_init_generate_records_skips_conversation_fields_for_non_conversation_entity(): - app_config = _make_app_config(AppMode.COMPLETION) - entity = DummyCompletionGenerateEntity(app_config=app_config) - - generator = MessageBasedAppGenerator() - - conversation, message = generator._init_generate_records(entity, conversation=None) - - assert conversation.id == "generated-conversation-id" - assert message.id == 
"generated-message-id" - assert hasattr(entity, "conversation_id") is False - assert hasattr(entity, "is_new_conversation") is False - - -def test_init_generate_records_sets_conversation_fields_for_chat_entity(): - app_config = _make_app_config(AppMode.CHAT) - entity = _make_chat_generate_entity(app_config) - - generator = MessageBasedAppGenerator() - - conversation, _ = generator._init_generate_records(entity, conversation=None) - - assert entity.conversation_id == "generated-conversation-id" - assert entity.is_new_conversation is True - assert conversation.id == "generated-conversation-id" diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py deleted file mode 100644 index 97c993928e..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ /dev/null @@ -1,287 +0,0 @@ -import sys -import time -from pathlib import Path -from types import ModuleType, SimpleNamespace -from typing import Any - -API_DIR = str(Path(__file__).resolve().parents[5]) -if API_DIR not in sys.path: - sys.path.insert(0, API_DIR) - -import core.workflow.nodes.human_input.entities # noqa: F401 -from core.app.apps.advanced_chat import app_generator as adv_app_gen_module -from core.app.apps.workflow import app_generator as wf_app_gen_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_events import ( - GraphEngineEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunSucceededEvent, -) -from core.workflow.node_events import NodeRunResult, PauseRequestedEvent -from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable - -if "core.ops.ops_trace_manager" not in sys.modules: - ops_stub = ModuleType("core.ops.ops_trace_manager") - - class _StubTraceQueueManager: - def __init__(self, *_, **__): - pass - - ops_stub.TraceQueueManager = _StubTraceQueueManager - sys.modules["core.ops.ops_trace_manager"] = ops_stub - - -class _StubToolNodeData(BaseNodeData): - pause_on: bool = False - - -class _StubToolNode(Node[_StubToolNodeData]): - node_type = NodeType.TOOL - - @classmethod - def version(cls) -> str: - return "1" - - def init_node_data(self, data): - self._node_data = _StubToolNodeData.model_validate(data) - - def _get_error_strategy(self): - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self): - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - 
return self._node_data - - def _run(self): - if self.node_data.pause_on: - yield PauseRequestedEvent(reason=SchedulingPause(message="test pause")) - return - - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"value": f"{self.id}-done"}, - ) - yield self._convert_node_run_result_to_graph_node_event(result) - - -def _patch_tool_node(mocker): - original_create_node = DifyNodeFactory.create_node - - def _patched_create_node(self, node_config: dict[str, object]) -> Node: - node_data = node_config.get("data", {}) - if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value: - return _StubToolNode( - id=str(node_config["id"]), - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return original_create_node(self, node_config) - - mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) - - -def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: - node_data = data.model_dump() - node_data["type"] = node_type.value - return node_data - - -def _build_graph_config(*, pause_on: str | None) -> dict[str, object]: - start_data = StartNodeData(title="start", variables=[]) - tool_data_a = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_a") - tool_data_b = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_b") - tool_data_c = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_c") - end_data = EndNodeData( - title="end", - outputs=[OutputVariableEntity(variable="result", value_selector=["tool_c", "value"])], - desc=None, - ) - - nodes = [ - {"id": "start", "data": _node_data(NodeType.START, start_data)}, - {"id": "tool_a", "data": _node_data(NodeType.TOOL, tool_data_a)}, - {"id": "tool_b", "data": _node_data(NodeType.TOOL, tool_data_b)}, - {"id": "tool_c", "data": _node_data(NodeType.TOOL, tool_data_c)}, - {"id": "end", "data": _node_data(NodeType.END, end_data)}, - ] - edges = [ - {"source": "start", "target": "tool_a"}, - {"source": "tool_a", "target": "tool_b"}, - {"source": "tool_b", "target": "tool_c"}, - {"source": "tool_c", "target": "end"}, - ] - return {"nodes": nodes, "edges": edges} - - -def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph: - graph_config = _build_graph_config(pause_on=pause_on) - params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="service-api", - call_depth=0, - ) - - node_factory = DifyNodeFactory( - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - return Graph.init(graph_config=graph_config, node_factory=node_factory) - - -def _build_runtime_state(run_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - variable_pool.system_variables.workflow_execution_id = run_id - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]: - command_channel = InMemoryChannel() - graph = _build_graph(runtime_state, pause_on=pause_on) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=command_channel, - ) - - events: list[GraphEngineEvent] = [] - for event in engine.run(): - 
events.append(event) - return events - - -def _node_successes(events: list[GraphEngineEvent]) -> list[str]: - return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)] - - -def test_workflow_app_pause_resume_matches_baseline(mocker): - _patch_tool_node(mocker) - - baseline_state = _build_runtime_state("baseline") - baseline_events = _run_with_optional_pause(baseline_state, pause_on=None) - assert isinstance(baseline_events[-1], GraphRunSucceededEvent) - baseline_nodes = _node_successes(baseline_events) - baseline_outputs = baseline_state.outputs - - paused_state = _build_runtime_state("paused-run") - paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") - assert isinstance(paused_events[-1], GraphRunPausedEvent) - paused_nodes = _node_successes(paused_events) - snapshot = paused_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - generator = wf_app_gen_module.WorkflowAppGenerator() - - def _fake_generate(**kwargs): - state: GraphRuntimeState = kwargs["graph_runtime_state"] - events = _run_with_optional_pause(state, pause_on=None) - return _node_successes(events) - - mocker.patch.object(generator, "_generate", side_effect=_fake_generate) - - resumed_nodes = generator.resume( - app_model=SimpleNamespace(mode="workflow"), - workflow=SimpleNamespace(), - user=SimpleNamespace(), - application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API), - graph_runtime_state=resumed_state, - workflow_execution_repository=SimpleNamespace(), - workflow_node_execution_repository=SimpleNamespace(), - ) - - assert paused_nodes + resumed_nodes == baseline_nodes - assert resumed_state.outputs == baseline_outputs - - -def test_advanced_chat_pause_resume_matches_baseline(mocker): - _patch_tool_node(mocker) - - baseline_state = _build_runtime_state("adv-baseline") - baseline_events = _run_with_optional_pause(baseline_state, pause_on=None) - assert isinstance(baseline_events[-1], GraphRunSucceededEvent) - baseline_nodes = _node_successes(baseline_events) - baseline_outputs = baseline_state.outputs - - paused_state = _build_runtime_state("adv-paused") - paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") - assert isinstance(paused_events[-1], GraphRunPausedEvent) - paused_nodes = _node_successes(paused_events) - snapshot = paused_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - generator = adv_app_gen_module.AdvancedChatAppGenerator() - - def _fake_generate(**kwargs): - state: GraphRuntimeState = kwargs["graph_runtime_state"] - events = _run_with_optional_pause(state, pause_on=None) - return _node_successes(events) - - mocker.patch.object(generator, "_generate", side_effect=_fake_generate) - - resumed_nodes = generator.resume( - app_model=SimpleNamespace(mode="workflow"), - workflow=SimpleNamespace(), - user=SimpleNamespace(), - conversation=SimpleNamespace(id="conv"), - message=SimpleNamespace(id="msg"), - application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API), - workflow_execution_repository=SimpleNamespace(), - workflow_node_execution_repository=SimpleNamespace(), - graph_runtime_state=resumed_state, - ) - - assert paused_nodes + resumed_nodes == baseline_nodes - assert resumed_state.outputs == baseline_outputs - - -def test_resume_emits_resumption_start_reason(mocker) -> None: - _patch_tool_node(mocker) - - paused_state = _build_runtime_state("resume-reason") - paused_events = _run_with_optional_pause(paused_state, 
pause_on="tool_a") - initial_start = next(event for event in paused_events if isinstance(event, GraphRunStartedEvent)) - assert initial_start.reason == WorkflowStartReason.INITIAL - - resumed_state = GraphRuntimeState.from_snapshot(paused_state.dumps()) - resumed_events = _run_with_optional_pause(resumed_state, pause_on=None) - resume_start = next(event for event in resumed_events if isinstance(event, GraphRunStartedEvent)) - assert resume_start.reason == WorkflowStartReason.RESUMPTION diff --git a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py deleted file mode 100644 index 7b5447c01e..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import annotations - -import json -import queue - -import pytest - -from core.app.apps.message_based_app_generator import MessageBasedAppGenerator -from core.app.entities.task_entities import StreamEvent -from models.model import AppMode - - -class FakeSubscription: - def __init__(self, message_queue: queue.Queue[bytes], state: dict[str, bool]) -> None: - self._queue = message_queue - self._state = state - self._closed = False - - def __enter__(self): - self._state["subscribed"] = True - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def close(self) -> None: - self._closed = True - - def receive(self, timeout: float | None = 0.1) -> bytes | None: - if self._closed: - return None - try: - if timeout is None: - return self._queue.get() - return self._queue.get(timeout=timeout) - except queue.Empty: - return None - - -class FakeTopic: - def __init__(self) -> None: - self._queue: queue.Queue[bytes] = queue.Queue() - self._state = {"subscribed": False} - - def subscribe(self) -> FakeSubscription: - return FakeSubscription(self._queue, self._state) - - def publish(self, payload: bytes) -> None: - self._queue.put(payload) - - @property - def subscribed(self) -> bool: - return self._state["subscribed"] - - -def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch): - topic = FakeTopic() - - def fake_get_response_topic(cls, app_mode, workflow_run_id): - return topic - - monkeypatch.setattr(MessageBasedAppGenerator, "get_response_topic", classmethod(fake_get_response_topic)) - - def on_subscribe() -> None: - assert topic.subscribed is True - event = {"event": StreamEvent.WORKFLOW_FINISHED.value} - topic.publish(json.dumps(event).encode()) - - generator = MessageBasedAppGenerator.retrieve_events( - AppMode.WORKFLOW, - "workflow-run-id", - idle_timeout=0.5, - on_subscribe=on_subscribe, - ) - - assert next(generator) == StreamEvent.PING.value - event = next(generator) - assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value - with pytest.raises(StopIteration): - next(generator) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py index 7e8367c6c4..83ac3a5591 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py @@ -1,6 +1,3 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator @@ -20,193 +17,3 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false(): args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False} assert 
WorkflowAppGenerator()._should_prepare_user_inputs(args) - - -def test_resume_delegates_to_generate(mocker): - generator = WorkflowAppGenerator() - mock_generate = mocker.patch.object(generator, "_generate", return_value="ok") - - application_generate_entity = SimpleNamespace(stream=False, invoke_from="debugger") - runtime_state = MagicMock(name="runtime-state") - pause_config = MagicMock(name="pause-config") - - result = generator.resume( - app_model=MagicMock(), - workflow=MagicMock(), - user=MagicMock(), - application_generate_entity=application_generate_entity, - graph_runtime_state=runtime_state, - workflow_execution_repository=MagicMock(), - workflow_node_execution_repository=MagicMock(), - graph_engine_layers=("layer",), - pause_state_config=pause_config, - variable_loader=MagicMock(), - ) - - assert result == "ok" - mock_generate.assert_called_once() - kwargs = mock_generate.call_args.kwargs - assert kwargs["graph_runtime_state"] is runtime_state - assert kwargs["pause_state_config"] is pause_config - assert kwargs["streaming"] is False - assert kwargs["invoke_from"] == "debugger" - - -def test_generate_appends_pause_layer_and_forwards_state(mocker): - generator = WorkflowAppGenerator() - - mock_queue_manager = MagicMock() - mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager) - - fake_current_app = MagicMock() - fake_current_app._get_current_object.return_value = MagicMock() - mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app) - - mocker.patch( - "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert", - return_value="converted", - ) - mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response") - mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock()) - - pause_layer = MagicMock(name="pause-layer") - mocker.patch( - "core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", - return_value=pause_layer, - ) - - dummy_session = MagicMock() - dummy_session.close = MagicMock() - mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session) - - worker_kwargs: dict[str, object] = {} - - class DummyThread: - def __init__(self, target, kwargs): - worker_kwargs["target"] = target - worker_kwargs["kwargs"] = kwargs - - def start(self): - return None - - mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread) - - app_model = SimpleNamespace(mode="workflow") - app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf") - application_generate_entity = SimpleNamespace( - task_id="task", - user_id="user", - invoke_from="service-api", - app_config=app_config, - files=[], - stream=True, - workflow_execution_id="run", - ) - - graph_runtime_state = MagicMock() - - result = generator._generate( - app_model=app_model, - workflow=MagicMock(), - user=MagicMock(), - application_generate_entity=application_generate_entity, - invoke_from="service-api", - workflow_execution_repository=MagicMock(), - workflow_node_execution_repository=MagicMock(), - streaming=True, - graph_engine_layers=("base-layer",), - graph_runtime_state=graph_runtime_state, - pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"), - ) - - assert result == "converted" - assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer) - assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state - - -def 
test_resume_path_runs_worker_with_runtime_state(mocker): - generator = WorkflowAppGenerator() - runtime_state = MagicMock(name="runtime-state") - - pause_layer = MagicMock(name="pause-layer") - mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer) - - queue_manager = MagicMock() - mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager) - - mocker.patch.object(generator, "_handle_response", return_value="raw-response") - mocker.patch( - "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert", - side_effect=lambda response, invoke_from: response, - ) - - fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock()) - mocker.patch("core.app.apps.workflow.app_generator.db", fake_db) - - workflow = SimpleNamespace( - id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1" - ) - end_user = SimpleNamespace(session_id="end-user-session") - app_record = SimpleNamespace(id="app") - - session = MagicMock() - session.__enter__.return_value = session - session.__exit__.return_value = False - session.scalar.side_effect = [workflow, end_user, app_record] - mocker.patch("core.app.apps.workflow.app_generator.session_factory", return_value=session) - - runner_instance = MagicMock() - - def runner_ctor(**kwargs): - assert kwargs["graph_runtime_state"] is runtime_state - return runner_instance - - mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor) - - class ImmediateThread: - def __init__(self, target, kwargs): - target(**kwargs) - - def start(self): - return None - - mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread) - - mocker.patch( - "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", - return_value=MagicMock(), - ) - mocker.patch( - "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", - return_value=MagicMock(), - ) - - pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner") - - app_model = SimpleNamespace(mode="workflow") - app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow") - application_generate_entity = SimpleNamespace( - task_id="task", - user_id="user", - invoke_from="service-api", - app_config=app_config, - files=[], - stream=True, - workflow_execution_id="run", - trace_manager=MagicMock(), - ) - - result = generator.resume( - app_model=app_model, - workflow=workflow, - user=MagicMock(), - application_generate_entity=application_generate_entity, - graph_runtime_state=runtime_state, - workflow_execution_repository=MagicMock(), - workflow_node_execution_repository=MagicMock(), - pause_state_config=pause_config, - ) - - assert result == "raw-response" - runner_instance.run.assert_called_once() - queue_manager.graph_runtime_state = runtime_state diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py deleted file mode 100644 index f4efb240c0..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from 
core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.graph_events.graph import GraphRunPausedEvent - - -class _DummyQueueManager: - def __init__(self): - self.published = [] - - def publish(self, event, _from): - self.published.append(event) - - -class _DummyRuntimeState: - def get_paused_nodes(self): - return ["node-1"] - - -class _DummyGraphEngine: - def __init__(self): - self.graph_runtime_state = _DummyRuntimeState() - - -class _DummyWorkflowEntry: - def __init__(self): - self.graph_engine = _DummyGraphEngine() - - -def test_handle_pause_event_enqueues_email_task(monkeypatch: pytest.MonkeyPatch): - queue_manager = _DummyQueueManager() - runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app-id") - workflow_entry = _DummyWorkflowEntry() - - reason = HumanInputRequired( - form_id="form-123", - form_content="content", - inputs=[], - actions=[], - node_id="node-1", - node_title="Review", - ) - event = GraphRunPausedEvent(reasons=[reason], outputs={}) - - email_task = MagicMock() - monkeypatch.setattr("core.app.apps.workflow_app_runner.dispatch_human_input_email_task", email_task) - - runner._handle_event(workflow_entry, event) - - email_task.apply_async.assert_called_once() - kwargs = email_task.apply_async.call_args.kwargs["kwargs"] - assert kwargs["form_id"] == "form-123" - assert kwargs["node_title"] == "Review" - - assert any(isinstance(evt, QueueWorkflowPausedEvent) for evt in queue_manager.published) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py deleted file mode 100644 index c30b925d88..0000000000 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ /dev/null @@ -1,183 +0,0 @@ -from datetime import UTC, datetime -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.app.apps.common import workflow_response_converter -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.app.apps.workflow.app_runner import WorkflowAppRunner -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph_events.graph import GraphRunPausedEvent -from core.workflow.nodes.human_input.entities import FormInput, UserAction -from core.workflow.nodes.human_input.enums import FormInputType -from core.workflow.system_variable import SystemVariable -from models.account import Account - - -class _RecordingWorkflowAppRunner(WorkflowAppRunner): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.published_events = [] - - def _publish_event(self, event): - self.published_events.append(event) - - -class _FakeRuntimeState: - def get_paused_nodes(self): - return ["node-pause-1"] - - -def _build_runner(): - app_entity = SimpleNamespace( - app_config=SimpleNamespace(app_id="app-id"), - inputs={}, - files=[], - invoke_from=InvokeFrom.SERVICE_API, - single_iteration_run=None, - single_loop_run=None, - workflow_execution_id="run-id", - user_id="user-id", - ) - workflow = SimpleNamespace( - graph_dict={}, - tenant_id="tenant-id", - environment_variables={}, - id="workflow-id", - ) - queue_manager = 
SimpleNamespace(publish=lambda event, pub_from: None) - return _RecordingWorkflowAppRunner( - application_generate_entity=app_entity, - queue_manager=queue_manager, - variable_loader=MagicMock(), - workflow=workflow, - system_user_id="sys-user", - root_node_id=None, - workflow_execution_repository=MagicMock(), - workflow_node_execution_repository=MagicMock(), - graph_engine_layers=(), - graph_runtime_state=None, - ) - - -def test_graph_run_paused_event_emits_queue_pause_event(): - runner = _build_runner() - reason = HumanInputRequired( - form_id="form-1", - form_content="content", - inputs=[], - actions=[], - node_id="node-human", - node_title="Human Step", - form_token="tok", - ) - event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"}) - workflow_entry = SimpleNamespace( - graph_engine=SimpleNamespace(graph_runtime_state=_FakeRuntimeState()), - ) - - runner._handle_event(workflow_entry, event) - - assert len(runner.published_events) == 1 - queue_event = runner.published_events[0] - assert isinstance(queue_event, QueueWorkflowPausedEvent) - assert queue_event.reasons == [reason] - assert queue_event.outputs == {"foo": "bar"} - assert queue_event.paused_nodes == ["node-pause-1"] - - -def _build_converter(): - application_generate_entity = SimpleNamespace( - inputs={}, - files=[], - invoke_from=InvokeFrom.SERVICE_API, - app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"), - ) - system_variables = SystemVariable( - user_id="user", - app_id="app-id", - workflow_id="workflow-id", - workflow_execution_id="run-id", - ) - user = MagicMock(spec=Account) - user.id = "account-id" - user.name = "Tester" - user.email = "tester@example.com" - return WorkflowResponseConverter( - application_generate_entity=application_generate_entity, - user=user, - system_variables=system_variables, - ) - - -def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.MonkeyPatch): - converter = _build_converter() - converter.workflow_start_to_stream_response( - task_id="task", - workflow_run_id="run-id", - workflow_id="workflow-id", - reason=WorkflowStartReason.INITIAL, - ) - - expiration_time = datetime(2024, 1, 1, tzinfo=UTC) - - class _FakeSession: - def execute(self, _stmt): - return [("form-1", expiration_time)] - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession()) - monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object())) - - reason = HumanInputRequired( - form_id="form-1", - form_content="Rendered", - inputs=[ - FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None), - ], - actions=[UserAction(id="approve", title="Approve")], - display_in_ui=True, - node_id="node-id", - node_title="Human Step", - form_token="token", - ) - queue_event = QueueWorkflowPausedEvent( - reasons=[reason], - outputs={"answer": "value"}, - paused_nodes=["node-id"], - ) - - runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0) - responses = converter.workflow_pause_to_stream_response( - event=queue_event, - task_id="task", - graph_runtime_state=runtime_state, - ) - - assert isinstance(responses[-1], WorkflowPauseStreamResponse) - pause_resp = responses[-1] - assert pause_resp.workflow_run_id == "run-id" - assert pause_resp.data.paused_nodes == ["node-id"] - assert pause_resp.data.outputs == {} - assert pause_resp.data.reasons[0]["form_id"] == "form-1" - assert 
pause_resp.data.reasons[0]["display_in_ui"] is True - - assert isinstance(responses[0], HumanInputRequiredResponse) - hi_resp = responses[0] - assert hi_resp.data.form_id == "form-1" - assert hi_resp.data.node_id == "node-id" - assert hi_resp.data.node_title == "Human Step" - assert hi_resp.data.inputs[0].output_variable_name == "field" - assert hi_resp.data.actions[0].id == "approve" - assert hi_resp.data.display_in_ui is True - assert hi_resp.data.expiration_time == int(expiration_time.timestamp()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py deleted file mode 100644 index 32cb1ed47c..0000000000 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ /dev/null @@ -1,96 +0,0 @@ -import time -from contextlib import contextmanager -from unittest.mock import MagicMock - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.apps.base_app_queue_manager import AppQueueManager -from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.entities.queue_entities import QueueWorkflowStartedEvent -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.account import Account -from models.model import AppMode - - -def _build_workflow_app_config() -> WorkflowUIBasedAppConfig: - return WorkflowUIBasedAppConfig( - tenant_id="tenant-id", - app_id="app-id", - app_mode=AppMode.WORKFLOW, - workflow_id="workflow-id", - ) - - -def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity: - return WorkflowAppGenerateEntity( - task_id="task-id", - app_config=_build_workflow_app_config(), - inputs={}, - files=[], - user_id="user-id", - stream=False, - invoke_from=InvokeFrom.SERVICE_API, - workflow_execution_id=run_id, - ) - - -def _build_runtime_state(run_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable(workflow_execution_id=run_id), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -@contextmanager -def _noop_session(): - yield MagicMock() - - -def _build_pipeline(run_id: str) -> WorkflowAppGenerateTaskPipeline: - queue_manager = MagicMock(spec=AppQueueManager) - queue_manager.invoke_from = InvokeFrom.SERVICE_API - queue_manager.graph_runtime_state = _build_runtime_state(run_id) - workflow = MagicMock() - workflow.id = "workflow-id" - workflow.features_dict = {} - user = Account(name="user", email="user@example.com") - pipeline = WorkflowAppGenerateTaskPipeline( - application_generate_entity=_build_generate_entity(run_id), - workflow=workflow, - queue_manager=queue_manager, - user=user, - stream=False, - draft_var_saver_factory=MagicMock(), - ) - pipeline._database_session = _noop_session - return pipeline - - -def test_workflow_app_log_saved_only_on_initial_start() -> None: - run_id = "run-initial" - pipeline = _build_pipeline(run_id) - pipeline._save_workflow_app_log = MagicMock() - - event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.INITIAL) - list(pipeline._handle_workflow_started_event(event)) - - pipeline._save_workflow_app_log.assert_called_once() - _, kwargs = 
pipeline._save_workflow_app_log.call_args - assert kwargs["workflow_run_id"] == run_id - assert pipeline._workflow_execution_id == run_id - - -def test_workflow_app_log_skipped_on_resumption_start() -> None: - run_id = "run-resume" - pipeline = _build_pipeline(run_id) - pipeline._save_workflow_app_log = MagicMock() - - event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.RESUMPTION) - list(pipeline._handle_workflow_started_event(event)) - - pipeline._save_workflow_app_log.assert_not_called() - assert pipeline._workflow_execution_id == run_id diff --git a/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py b/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py deleted file mode 100644 index 86c80985c4..0000000000 --- a/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py +++ /dev/null @@ -1,143 +0,0 @@ -import json -from collections.abc import Callable -from dataclasses import dataclass - -import pytest - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, - InvokeFrom, - WorkflowAppGenerateEntity, -) -from core.app.layers.pause_state_persist_layer import ( - WorkflowResumptionContext, - _AdvancedChatAppGenerateEntityWrapper, - _WorkflowGenerateEntityWrapper, -) -from core.ops.ops_trace_manager import TraceQueueManager -from models.model import AppMode - - -class TraceQueueManagerStub(TraceQueueManager): - """Minimal TraceQueueManager stub that avoids Flask dependencies.""" - - def __init__(self): - # Skip parent initialization to avoid starting timers or accessing Flask globals. - pass - - -def _build_workflow_app_config(app_mode: AppMode) -> WorkflowUIBasedAppConfig: - return WorkflowUIBasedAppConfig( - tenant_id="tenant-id", - app_id="app-id", - app_mode=app_mode, - workflow_id=f"{app_mode.value}-workflow-id", - ) - - -def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = None) -> WorkflowAppGenerateEntity: - return WorkflowAppGenerateEntity( - task_id="workflow-task", - app_config=_build_workflow_app_config(AppMode.WORKFLOW), - inputs={"topic": "serialization"}, - files=[], - user_id="user-workflow", - stream=True, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=1, - trace_manager=trace_manager, - workflow_execution_id="workflow-exec-id", - extras={"external_trace_id": "trace-id"}, - ) - - -def _create_advanced_chat_generate_entity( - trace_manager: TraceQueueManager | None = None, -) -> AdvancedChatAppGenerateEntity: - return AdvancedChatAppGenerateEntity( - task_id="advanced-task", - app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT), - conversation_id="conversation-id", - inputs={"topic": "roundtrip"}, - files=[], - user_id="user-advanced", - stream=False, - invoke_from=InvokeFrom.DEBUGGER, - query="Explain serialization", - extras={"auto_generate_conversation_name": True}, - trace_manager=trace_manager, - workflow_run_id="workflow-run-id", - ) - - -def test_workflow_app_generate_entity_roundtrip_excludes_trace_manager(): - entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub()) - - serialized = entity.model_dump_json() - payload = json.loads(serialized) - - assert "trace_manager" not in payload - - restored = WorkflowAppGenerateEntity.model_validate_json(serialized) - - assert restored.model_dump() == entity.model_dump() - assert restored.trace_manager is None - - -def test_advanced_chat_generate_entity_roundtrip_excludes_trace_manager(): - entity = 
_create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub()) - - serialized = entity.model_dump_json() - payload = json.loads(serialized) - - assert "trace_manager" not in payload - - restored = AdvancedChatAppGenerateEntity.model_validate_json(serialized) - - assert restored.model_dump() == entity.model_dump() - assert restored.trace_manager is None - - -@dataclass(frozen=True) -class ResumptionContextCase: - name: str - context_factory: Callable[[], tuple[WorkflowResumptionContext, type]] - - -def _workflow_resumption_case() -> tuple[WorkflowResumptionContext, type]: - entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub()) - context = WorkflowResumptionContext( - serialized_graph_runtime_state=json.dumps({"state": "workflow"}), - generate_entity=_WorkflowGenerateEntityWrapper(entity=entity), - ) - return context, WorkflowAppGenerateEntity - - -def _advanced_chat_resumption_case() -> tuple[WorkflowResumptionContext, type]: - entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub()) - context = WorkflowResumptionContext( - serialized_graph_runtime_state=json.dumps({"state": "advanced"}), - generate_entity=_AdvancedChatAppGenerateEntityWrapper(entity=entity), - ) - return context, AdvancedChatAppGenerateEntity - - -@pytest.mark.parametrize( - "case", - [ - pytest.param(ResumptionContextCase("workflow", _workflow_resumption_case), id="workflow"), - pytest.param(ResumptionContextCase("advanced_chat", _advanced_chat_resumption_case), id="advanced_chat"), - ], -) -def test_workflow_resumption_context_roundtrip(case: ResumptionContextCase): - context, expected_type = case.context_factory() - - serialized = context.dumps() - restored = WorkflowResumptionContext.loads(serialized) - - assert restored.serialized_graph_runtime_state == context.serialized_graph_runtime_state - entity = restored.get_generate_entity() - assert isinstance(entity, expected_type) - assert entity.model_dump() == context.get_generate_entity().model_dump() - assert entity.trace_manager is None diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py deleted file mode 100644 index a380149554..0000000000 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py +++ /dev/null @@ -1,72 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig -from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation -from models.model import AppMode - - -def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker): - workflow = MagicMock() - workflow.created_by = "owner-id" - - app = MagicMock() - app.mode = AppMode.ADVANCED_CHAT - app.workflow = workflow - - mocker.patch( - "core.plugin.backwards_invocation.app.db", - SimpleNamespace(engine=MagicMock()), - ) - generator_spy = mocker.patch( - "core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate", - return_value={"result": "ok"}, - ) - - result = PluginAppBackwardsInvocation.invoke_chat_app( - app=app, - user=MagicMock(), - conversation_id="conv-1", - query="hello", - stream=False, - inputs={"k": "v"}, - files=[], - ) - - assert result == {"result": "ok"} - call_kwargs = generator_spy.call_args.kwargs - pause_state_config = call_kwargs.get("pause_state_config") - assert isinstance(pause_state_config, PauseStateLayerConfig) - assert pause_state_config.state_owner_user_id == 
"owner-id" - - -def test_invoke_workflow_app_injects_pause_state_config(mocker): - workflow = MagicMock() - workflow.created_by = "owner-id" - - app = MagicMock() - app.mode = AppMode.WORKFLOW - app.workflow = workflow - - mocker.patch( - "core.plugin.backwards_invocation.app.db", - SimpleNamespace(engine=MagicMock()), - ) - generator_spy = mocker.patch( - "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate", - return_value={"result": "ok"}, - ) - - result = PluginAppBackwardsInvocation.invoke_workflow_app( - app=app, - user=MagicMock(), - stream=False, - inputs={"k": "v"}, - files=[], - ) - - assert result == {"result": "ok"} - call_kwargs = generator_spy.call_args.kwargs - pause_state_config = call_kwargs.get("pause_state_config") - assert isinstance(pause_state_config, PauseStateLayerConfig) - assert pause_state_config.state_owner_user_id == "owner-id" diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py deleted file mode 100644 index 811ed2143b..0000000000 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ /dev/null @@ -1,574 +0,0 @@ -"""Unit tests for HumanInputFormRepositoryImpl private helpers.""" - -from __future__ import annotations - -import dataclasses -from datetime import datetime -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.repositories.human_input_repository import ( - HumanInputFormRecord, - HumanInputFormRepositoryImpl, - HumanInputFormSubmissionRepository, - _WorkspaceMemberInfo, -) -from core.workflow.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - FormDefinition, - MemberRecipient, - UserAction, -) -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from libs.datetime_utils import naive_utc_now -from models.human_input import ( - EmailExternalRecipientPayload, - EmailMemberRecipientPayload, - HumanInputFormRecipient, - RecipientType, -) - - -def _build_repository() -> HumanInputFormRepositoryImpl: - return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id") - - -def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]: - created: list[SimpleNamespace] = [] - - def fake_new(cls, form_id: str, delivery_id: str, payload): # type: ignore[no-untyped-def] - recipient = SimpleNamespace( - form_id=form_id, - delivery_id=delivery_id, - recipient_type=payload.TYPE, - recipient_payload=payload.model_dump_json(), - ) - created.append(recipient) - return recipient - - monkeypatch.setattr(HumanInputFormRecipient, "new", classmethod(fake_new)) - return created - - -@pytest.fixture(autouse=True) -def _stub_selectinload(monkeypatch: pytest.MonkeyPatch) -> None: - """Avoid SQLAlchemy mapper configuration in tests using fake sessions.""" - - class _FakeSelect: - def options(self, *_args, **_kwargs): # type: ignore[no-untyped-def] - return self - - def where(self, *_args, **_kwargs): # type: ignore[no-untyped-def] - return self - - monkeypatch.setattr( - "core.repositories.human_input_repository.selectinload", lambda *args, **kwargs: "_loader_option" - ) - monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *args, **kwargs: _FakeSelect()) - - -class TestHumanInputFormRepositoryImplHelpers: - def test_build_email_recipients_with_member_and_external(self, 
monkeypatch: pytest.MonkeyPatch) -> None: - repo = _build_repository() - session_stub = object() - _patch_recipient_factory(monkeypatch) - - def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] - assert session is session_stub - assert restrict_to_user_ids == ["member-1"] - return [_WorkspaceMemberInfo(user_id="member-1", email="member@example.com")] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) - - recipients = repo._build_email_recipients( - session=session_stub, - form_id="form-id", - delivery_id="delivery-id", - recipients_config=EmailRecipients( - whole_workspace=False, - items=[ - MemberRecipient(user_id="member-1"), - ExternalRecipient(email="external@example.com"), - ], - ), - ) - - assert len(recipients) == 2 - member_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_MEMBER) - external_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL) - - member_payload = EmailMemberRecipientPayload.model_validate_json(member_recipient.recipient_payload) - assert member_payload.user_id == "member-1" - assert member_payload.email == "member@example.com" - - external_payload = EmailExternalRecipientPayload.model_validate_json(external_recipient.recipient_payload) - assert external_payload.email == "external@example.com" - - def test_build_email_recipients_skips_unknown_members(self, monkeypatch: pytest.MonkeyPatch) -> None: - repo = _build_repository() - session_stub = object() - created = _patch_recipient_factory(monkeypatch) - - def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] - assert session is session_stub - assert restrict_to_user_ids == ["missing-member"] - return [] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) - - recipients = repo._build_email_recipients( - session=session_stub, - form_id="form-id", - delivery_id="delivery-id", - recipients_config=EmailRecipients( - whole_workspace=False, - items=[ - MemberRecipient(user_id="missing-member"), - ExternalRecipient(email="external@example.com"), - ], - ), - ) - - assert len(recipients) == 1 - assert recipients[0].recipient_type == RecipientType.EMAIL_EXTERNAL - assert len(created) == 1 # only external recipient created via factory - - def test_build_email_recipients_whole_workspace_uses_all_members(self, monkeypatch: pytest.MonkeyPatch) -> None: - repo = _build_repository() - session_stub = object() - _patch_recipient_factory(monkeypatch) - - def fake_query(self, session): # type: ignore[no-untyped-def] - assert session is session_stub - return [ - _WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"), - _WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"), - ] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query) - - recipients = repo._build_email_recipients( - session=session_stub, - form_id="form-id", - delivery_id="delivery-id", - recipients_config=EmailRecipients( - whole_workspace=True, - items=[], - ), - ) - - assert len(recipients) == 2 - emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients} - assert emails == {"member1@example.com", "member2@example.com"} - - def test_build_email_recipients_dedupes_external_by_email(self, monkeypatch: pytest.MonkeyPatch) -> None: - repo = _build_repository() - session_stub = object() - created = 
_patch_recipient_factory(monkeypatch) - - def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] - assert session is session_stub - assert restrict_to_user_ids == [] - return [] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) - - recipients = repo._build_email_recipients( - session=session_stub, - form_id="form-id", - delivery_id="delivery-id", - recipients_config=EmailRecipients( - whole_workspace=False, - items=[ - ExternalRecipient(email="external@example.com"), - ExternalRecipient(email="external@example.com"), - ], - ), - ) - - assert len(recipients) == 1 - assert len(created) == 1 - - def test_build_email_recipients_prefers_member_over_external_by_email( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: - repo = _build_repository() - session_stub = object() - _patch_recipient_factory(monkeypatch) - - def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] - assert session is session_stub - assert restrict_to_user_ids == ["member-1"] - return [_WorkspaceMemberInfo(user_id="member-1", email="shared@example.com")] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) - - recipients = repo._build_email_recipients( - session=session_stub, - form_id="form-id", - delivery_id="delivery-id", - recipients_config=EmailRecipients( - whole_workspace=False, - items=[ - MemberRecipient(user_id="member-1"), - ExternalRecipient(email="shared@example.com"), - ], - ), - ) - - assert len(recipients) == 1 - assert recipients[0].recipient_type == RecipientType.EMAIL_MEMBER - - def test_delivery_method_to_model_includes_external_recipients_with_whole_workspace( - self, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - repo = _build_repository() - session_stub = object() - _patch_recipient_factory(monkeypatch) - - def fake_query(self, session): # type: ignore[no-untyped-def] - assert session is session_stub - return [ - _WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"), - _WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"), - ] - - monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query) - - method = EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=True, - items=[ExternalRecipient(email="external@example.com")], - ), - subject="subject", - body="body", - ) - ) - - result = repo._delivery_method_to_model(session=session_stub, form_id="form-id", delivery_method=method) - - assert len(result.recipients) == 3 - member_emails = { - EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email - for r in result.recipients - if r.recipient_type == RecipientType.EMAIL_MEMBER - } - assert member_emails == {"member1@example.com", "member2@example.com"} - external_payload = EmailExternalRecipientPayload.model_validate_json( - next(r for r in result.recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL).recipient_payload - ) - assert external_payload.email == "external@example.com" - - -def _make_form_definition() -> str: - return FormDefinition( - form_content="hello", - inputs=[], - user_actions=[UserAction(id="submit", title="Submit")], - rendered_content="
hello
", - expiration_time=datetime.utcnow(), - ).model_dump_json() - - -@dataclasses.dataclass -class _DummyForm: - id: str - workflow_run_id: str - node_id: str - tenant_id: str - app_id: str - form_definition: str - rendered_content: str - expiration_time: datetime - form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME - created_at: datetime = dataclasses.field(default_factory=naive_utc_now) - selected_action_id: str | None = None - submitted_data: str | None = None - submitted_at: datetime | None = None - submission_user_id: str | None = None - submission_end_user_id: str | None = None - completed_by_recipient_id: str | None = None - status: HumanInputFormStatus = HumanInputFormStatus.WAITING - - -@dataclasses.dataclass -class _DummyRecipient: - id: str - form_id: str - recipient_type: RecipientType - access_token: str - form: _DummyForm | None = None - - -class _FakeScalarResult: - def __init__(self, obj): - self._obj = obj - - def first(self): - if isinstance(self._obj, list): - return self._obj[0] if self._obj else None - return self._obj - - def all(self): - if isinstance(self._obj, list): - return list(self._obj) - if self._obj is None: - return [] - return [self._obj] - - -class _FakeSession: - def __init__( - self, - *, - scalars_result=None, - scalars_results: list[object] | None = None, - forms: dict[str, _DummyForm] | None = None, - recipients: dict[str, _DummyRecipient] | None = None, - ): - if scalars_results is not None: - self._scalars_queue = list(scalars_results) - elif scalars_result is not None: - self._scalars_queue = [scalars_result] - else: - self._scalars_queue = [] - self.forms = forms or {} - self.recipients = recipients or {} - - def scalars(self, _query): - if self._scalars_queue: - result = self._scalars_queue.pop(0) - else: - result = None - return _FakeScalarResult(result) - - def get(self, model_cls, obj_id): # type: ignore[no-untyped-def] - if getattr(model_cls, "__name__", None) == "HumanInputForm": - return self.forms.get(obj_id) - if getattr(model_cls, "__name__", None) == "HumanInputFormRecipient": - return self.recipients.get(obj_id) - return None - - def add(self, _obj): - return None - - def flush(self): - return None - - def refresh(self, _obj): - return None - - def begin(self): - return self - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return None - - -def _session_factory(session: _FakeSession): - class _SessionContext: - def __enter__(self): - return session - - def __exit__(self, exc_type, exc, tb): - return None - - def _factory(*_args, **_kwargs): - return _SessionContext() - - return _factory - - -class TestHumanInputFormRepositoryImplPublicMethods: - def test_get_form_returns_entity_and_recipients(self): - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-id", - app_id="app-id", - form_definition=_make_form_definition(), - rendered_content="
hello
", - expiration_time=naive_utc_now(), - ) - recipient = _DummyRecipient( - id="recipient-1", - form_id=form.id, - recipient_type=RecipientType.STANDALONE_WEB_APP, - access_token="token-123", - ) - session = _FakeSession(scalars_results=[form, [recipient]]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") - - entity = repo.get_form(form.workflow_run_id, form.node_id) - - assert entity is not None - assert entity.id == form.id - assert entity.web_app_token == "token-123" - assert len(entity.recipients) == 1 - assert entity.recipients[0].token == "token-123" - - def test_get_form_returns_none_when_missing(self): - session = _FakeSession(scalars_results=[None]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") - - assert repo.get_form("run-1", "node-1") is None - - def test_get_form_returns_unsubmitted_state(self): - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-id", - app_id="app-id", - form_definition=_make_form_definition(), - rendered_content="
hello
", - expiration_time=naive_utc_now(), - ) - session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") - - entity = repo.get_form(form.workflow_run_id, form.node_id) - - assert entity is not None - assert entity.submitted is False - assert entity.selected_action_id is None - assert entity.submitted_data is None - - def test_get_form_returns_submission_when_completed(self): - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-id", - app_id="app-id", - form_definition=_make_form_definition(), - rendered_content="
hello
", - expiration_time=naive_utc_now(), - selected_action_id="approve", - submitted_data='{"field": "value"}', - submitted_at=naive_utc_now(), - ) - session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") - - entity = repo.get_form(form.workflow_run_id, form.node_id) - - assert entity is not None - assert entity.submitted is True - assert entity.selected_action_id == "approve" - assert entity.submitted_data == {"field": "value"} - - -class TestHumanInputFormSubmissionRepository: - def test_get_by_token_returns_record(self): - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-1", - app_id="app-1", - form_definition=_make_form_definition(), - rendered_content="
hello
", - expiration_time=naive_utc_now(), - ) - recipient = _DummyRecipient( - id="recipient-1", - form_id=form.id, - recipient_type=RecipientType.STANDALONE_WEB_APP, - access_token="token-123", - form=form, - ) - session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) - - record = repo.get_by_token("token-123") - - assert record is not None - assert record.form_id == form.id - assert record.recipient_type == RecipientType.STANDALONE_WEB_APP - assert record.submitted is False - - def test_get_by_form_id_and_recipient_type_uses_recipient(self): - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-1", - app_id="app-1", - form_definition=_make_form_definition(), - rendered_content="
hello
", - expiration_time=naive_utc_now(), - ) - recipient = _DummyRecipient( - id="recipient-1", - form_id=form.id, - recipient_type=RecipientType.STANDALONE_WEB_APP, - access_token="token-123", - form=form, - ) - session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) - - record = repo.get_by_form_id_and_recipient_type( - form_id=form.id, - recipient_type=RecipientType.STANDALONE_WEB_APP, - ) - - assert record is not None - assert record.recipient_id == recipient.id - assert record.access_token == recipient.access_token - - def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch): - fixed_now = datetime(2024, 1, 1, 0, 0, 0) - monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) - - form = _DummyForm( - id="form-1", - workflow_run_id="run-1", - node_id="node-1", - tenant_id="tenant-1", - app_id="app-1", - form_definition=_make_form_definition(), - rendered_content="
hello
", - expiration_time=fixed_now, - ) - recipient = _DummyRecipient( - id="recipient-1", - form_id="form-1", - recipient_type=RecipientType.STANDALONE_WEB_APP, - access_token="token-123", - ) - session = _FakeSession( - forms={form.id: form}, - recipients={recipient.id: recipient}, - ) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) - - record: HumanInputFormRecord = repo.mark_submitted( - form_id=form.id, - recipient_id=recipient.id, - selected_action_id="approve", - form_data={"field": "value"}, - submission_user_id="user-1", - submission_end_user_id="end-user-1", - ) - - assert form.selected_action_id == "approve" - assert form.completed_by_recipient_id == recipient.id - assert form.submission_user_id == "user-1" - assert form.submission_end_user_id == "end-user-1" - assert form.submitted_at == fixed_now - assert record.submitted is True - assert record.selected_action_id == "approve" - assert record.submitted_data == {"field": "value"} diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py deleted file mode 100644 index c46e31d90f..0000000000 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ /dev/null @@ -1,33 +0,0 @@ -import pytest - -from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils - - -def test_ensure_no_human_input_nodes_passes_for_non_human_input(): - graph = { - "nodes": [ - { - "id": "start_node", - "data": {"type": "start"}, - } - ] - } - - WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph) - - -def test_ensure_no_human_input_nodes_raises_for_human_input(): - graph = { - "nodes": [ - { - "id": "human_input_node", - "data": {"type": "human-input"}, - } - ] - } - - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index bbedfdb6ae..cd45292488 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -55,43 +55,6 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel assert exc_info.value.args == ("oops",) -def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch): - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) - - monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) - monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) - - from unittest.mock import MagicMock, Mock - - mock_user = Mock() - monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) - - generate_mock = MagicMock(return_value={"data": {}}) - 
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) - monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) - - list(tool.invoke("test_user", {})) - - call_kwargs = generate_mock.call_args.kwargs - assert "pause_state_config" in call_kwargs - assert call_kwargs["pause_state_config"] is None - - def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch): """Test that WorkflowTool should generate variable messages when there are outputs""" entity = ToolEntity( diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 1b6d03e36a..deff06fc5d 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -118,6 +118,7 @@ class TestGraphRuntimeState: from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue assert isinstance(queue, InMemoryReadyQueue) + assert state.ready_queue is queue def test_graph_execution_lazy_instantiation(self): state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py deleted file mode 100644 index 6144df06e0..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Tests for PauseReason discriminated union serialization/deserialization. -""" - -import pytest -from pydantic import BaseModel, ValidationError - -from core.workflow.entities.pause_reason import ( - HumanInputRequired, - PauseReason, - SchedulingPause, -) - - -class _Holder(BaseModel): - """Helper model that embeds PauseReason for union tests.""" - - reason: PauseReason - - -class TestPauseReasonDiscriminator: - """Test suite for PauseReason union discriminator.""" - - @pytest.mark.parametrize( - ("dict_value", "expected"), - [ - pytest.param( - { - "reason": { - "TYPE": "human_input_required", - "form_id": "form_id", - "form_content": "form_content", - "node_id": "node_id", - "node_title": "node_title", - }, - }, - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - id="HumanInputRequired", - ), - pytest.param( - { - "reason": { - "TYPE": "scheduled_pause", - "message": "Hold on", - } - }, - SchedulingPause(message="Hold on"), - id="SchedulingPause", - ), - ], - ) - def test_model_validate(self, dict_value, expected): - """Ensure scheduled pause payloads with lowercase TYPE deserialize.""" - holder = _Holder.model_validate(dict_value) - - assert type(holder.reason) == type(expected) - assert holder.reason == expected - - @pytest.mark.parametrize( - "reason", - [ - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - SchedulingPause(message="Hold on"), - ], - ids=lambda x: type(x).__name__, - ) - def test_model_construct(self, reason): - holder = _Holder(reason=reason) - assert holder.reason == reason - - def test_model_construct_with_invalid_type(self): - with pytest.raises(ValidationError): - holder = _Holder(reason=object()) # type: ignore - - def test_unknown_type_fails_validation(self): - """Unknown TYPE values should raise a validation error.""" - with pytest.raises(ValidationError): - 
_Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}})
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py
deleted file mode 100644
index 2ef23c7f0f..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py
+++ /dev/null
@@ -1,131 +0,0 @@
-"""Utilities for testing HumanInputNode without database dependencies."""
-
-from __future__ import annotations
-
-from collections.abc import Mapping
-from dataclasses import dataclass
-from datetime import datetime, timedelta
-from typing import Any
-
-from core.workflow.nodes.human_input.enums import HumanInputFormStatus
-from core.workflow.repositories.human_input_form_repository import (
-    FormCreateParams,
-    HumanInputFormEntity,
-    HumanInputFormRecipientEntity,
-    HumanInputFormRepository,
-)
-from libs.datetime_utils import naive_utc_now
-
-
-class _InMemoryFormRecipient(HumanInputFormRecipientEntity):
-    """Minimal recipient entity required by the repository interface."""
-
-    def __init__(self, recipient_id: str, token: str) -> None:
-        self._id = recipient_id
-        self._token = token
-
-    @property
-    def id(self) -> str:
-        return self._id
-
-    @property
-    def token(self) -> str:
-        return self._token
-
-
-@dataclass
-class _InMemoryFormEntity(HumanInputFormEntity):
-    form_id: str
-    rendered: str
-    token: str | None = None
-    action_id: str | None = None
-    data: Mapping[str, Any] | None = None
-    is_submitted: bool = False
-    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
-    expiration: datetime = naive_utc_now()
-
-    @property
-    def id(self) -> str:
-        return self.form_id
-
-    @property
-    def web_app_token(self) -> str | None:
-        return self.token
-
-    @property
-    def recipients(self) -> list[HumanInputFormRecipientEntity]:
-        return []
-
-    @property
-    def rendered_content(self) -> str:
-        return self.rendered
-
-    @property
-    def selected_action_id(self) -> str | None:
-        return self.action_id
-
-    @property
-    def submitted_data(self) -> Mapping[str, Any] | None:
-        return self.data
-
-    @property
-    def submitted(self) -> bool:
-        return self.is_submitted
-
-    @property
-    def status(self) -> HumanInputFormStatus:
-        return self.status_value
-
-    @property
-    def expiration_time(self) -> datetime:
-        return self.expiration
-
-
-class InMemoryHumanInputFormRepository(HumanInputFormRepository):
-    """Pure in-memory repository used by workflow graph engine tests."""
-
-    def __init__(self) -> None:
-        self._form_counter = 0
-        self.created_params: list[FormCreateParams] = []
-        self.created_forms: list[_InMemoryFormEntity] = []
-        self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {}
-
-    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
-        self.created_params.append(params)
-        self._form_counter += 1
-        form_id = f"form-{self._form_counter}"
-        token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}"
-        entity = _InMemoryFormEntity(
-            form_id=form_id,
-            rendered=params.rendered_content,
-            token=token,
-        )
-        self.created_forms.append(entity)
-        self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity
-        return entity
-
-    def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
-        return self._forms_by_key.get((workflow_execution_id, node_id))
-
-    # Convenience helpers for tests -------------------------------------
-
-    def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None:
-        """Simulate a human submission for the next repository lookup."""
-
-        if not self.created_forms:
-            raise AssertionError("no form has been created to attach submission data")
-        entity = self.created_forms[-1]
-        entity.action_id = action_id
-        entity.data = form_data or {}
-        entity.is_submitted = True
-        entity.status_value = HumanInputFormStatus.SUBMITTED
-        entity.expiration = naive_utc_now() + timedelta(days=1)
-
-    def clear_submission(self) -> None:
-        if not self.created_forms:
-            return
-        for form in self.created_forms:
-            form.action_id = None
-            form.data = None
-            form.is_submitted = False
-            form.status_value = HumanInputFormStatus.WAITING
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py
deleted file mode 100644
index 6038a15211..0000000000
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import queue
-import threading
-from datetime import datetime
-
-from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
-from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher
-from core.workflow.graph_events import NodeRunSucceededEvent
-from core.workflow.node_events import NodeRunResult
-
-
-class StubExecutionCoordinator:
-    def __init__(self, paused: bool) -> None:
-        self._paused = paused
-        self.mark_complete_called = False
-        self.failed_error: Exception | None = None
-
-    @property
-    def aborted(self) -> bool:
-        return False
-
-    @property
-    def paused(self) -> bool:
-        return self._paused
-
-    @property
-    def execution_complete(self) -> bool:
-        return False
-
-    def check_scaling(self) -> None:
-        return None
-
-    def process_commands(self) -> None:
-        return None
-
-    def mark_complete(self) -> None:
-        self.mark_complete_called = True
-
-    def mark_failed(self, error: Exception) -> None:
-        self.failed_error = error
-
-
-class StubEventHandler:
-    def __init__(self) -> None:
-        self.events: list[object] = []
-
-    def dispatch(self, event: object) -> None:
-        self.events.append(event)
-
-
-def test_dispatcher_drains_events_when_paused() -> None:
-    event_queue: queue.Queue = queue.Queue()
-    event = NodeRunSucceededEvent(
-        id="exec-1",
-        node_id="node-1",
-        node_type=NodeType.START,
-        start_at=datetime.utcnow(),
-        node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
-    )
-    event_queue.put(event)
-
-    handler = StubEventHandler()
-    coordinator = StubExecutionCoordinator(paused=True)
-    dispatcher = Dispatcher(
-        event_queue=event_queue,
-        event_handler=handler,
-        execution_coordinator=coordinator,
-        event_emitter=None,
-        stop_event=threading.Event(),
-    )
-
-    dispatcher._dispatcher_loop()
-
-    assert handler.events == [event]
-    assert coordinator.mark_complete_called is True
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py
index 53de8908a8..0d67a76169 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py
@@ -2,8 +2,6 @@
 
 from unittest.mock import MagicMock
 
-import pytest
-
 from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor
 from core.workflow.graph_engine.domain.graph_execution import GraphExecution
 from core.workflow.graph_engine.graph_state_manager import GraphStateManager
@@ -50,13 +48,3
@@ def test_handle_pause_noop_when_execution_running() -> None: worker_pool.stop.assert_not_called() state_manager.clear_executing.assert_not_called() - - -def test_has_executing_nodes_requires_pause() -> None: - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, _, _ = _build_coordinator(graph_execution) - - with pytest.raises(AssertionError): - coordinator.has_executing_nodes() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py deleted file mode 100644 index 65d34c2009..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ /dev/null @@ -1,189 +0,0 @@ -import time -from collections.abc import Mapping - -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeState -from core.workflow.graph import Graph -from core.workflow.graph_engine.graph_state_manager import GraphStateManager -from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_llm_node( - *, - node_id: str, - runtime_state: GraphRuntimeState, - graph_init_params: GraphInitParams, - mock_config: MockConfig, -) -> MockLLMNode: - llm_data = LLMNodeData( - title=f"LLM {node_id}", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=f"Prompt {node_id}", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - return MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - -def _build_graph(runtime_state: GraphRuntimeState) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - 
graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - mock_config = MockConfig() - llm_a = _build_llm_node( - node_id="llm_a", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - llm_b = _build_llm_node( - node_id="llm_b", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - - end_data = EndNodeData(title="End", outputs=[], desc=None) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - builder = ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(llm_b, from_node_id="start") - .add_node(end_node, from_node_id="llm_a") - ) - return builder.connect(tail="llm_b", head="end").build() - - -def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]: - return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()} - - -def test_runtime_state_snapshot_restores_graph_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - graph.nodes["llm_a"].state = NodeState.TAKEN - graph.nodes["llm_b"].state = NodeState.SKIPPED - - for edge in graph.edges.values(): - if edge.tail == "start" and edge.head == "llm_a": - edge.state = NodeState.TAKEN - elif edge.tail == "start" and edge.head == "llm_b": - edge.state = NodeState.SKIPPED - elif edge.head == "end" and edge.tail == "llm_a": - edge.state = NodeState.TAKEN - elif edge.head == "end" and edge.tail == "llm_b": - edge.state = NodeState.SKIPPED - - snapshot = runtime_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN - assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED - assert _edge_state_map(resumed_graph) == _edge_state_map(graph) - - -def test_join_readiness_uses_restored_edge_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - ready_queue = InMemoryReadyQueue() - state_manager = GraphStateManager(graph, ready_queue) - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_a": - edge.state = NodeState.TAKEN - if edge.tail == "llm_b": - edge.state = NodeState.UNKNOWN - - assert state_manager.is_node_ready("end") is False - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_b": - edge.state = NodeState.TAKEN - - assert state_manager.is_node_ready("end") is True - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue()) - assert resumed_state_manager.is_node_ready("end") is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 194d009288..c398e4e8c1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -1,7 +1,5 
@@ -import datetime import time from collections.abc import Iterable -from unittest.mock import MagicMock from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.message_entities import PromptMessageRole @@ -16,12 +14,11 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.human_input import HumanInputNode +from core.workflow.nodes.human_input.entities import HumanInputNodeData from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, @@ -31,21 +28,15 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode from .test_table_runner import TableTestRunner, WorkflowTestCase -def _build_branching_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: +def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} graph_init_params = GraphInitParams( tenant_id="tenant", @@ -58,18 +49,12 @@ def _build_branching_graph( call_depth=0, ) - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( @@ -108,21 +93,15 @@ def _build_branching_graph( human_data = HumanInputNodeData( title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="primary", title="Primary"), - UserAction(id="secondary", title="Secondary"), - ], + required_variables=["human.input_ready"], + pause_reason="Awaiting human input", ) - human_config = {"id": "human", "data": human_data.model_dump()} human_node = HumanInputNode( id=human_config["id"], config=human_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=form_repository, ) llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") @@ 
-240,18 +219,8 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: for scenario in branch_scenarios: runner = TableTestRunner() - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]: - return _build_branching_graph(mock_config, mock_create_repo) + def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]: + return _build_branching_graph(mock_config) initial_case = WorkflowTestCase( description="HumanInput pause before branching decision", @@ -273,16 +242,23 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: assert initial_result.success, initial_result.event_mismatch_details assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events) + graph_runtime_state = initial_result.graph_runtime_state + graph = initial_result.graph + assert graph_runtime_state is not None + assert graph is not None + + graph_runtime_state.variable_pool.add(("human", "input_ready"), True) + graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"]) + graph_runtime_state.graph_execution.pause_reason = None + pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"]) post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"]) - expected_pre_chunk_events_in_resumption = [ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunHumanInputFormFilledEvent, - ] expected_resume_sequence: list[type] = ( - expected_pre_chunk_events_in_resumption + [ + GraphRunStartedEvent, + NodeRunStartedEvent, + ] + [NodeRunStreamChunkEvent] * pre_chunk_count + [ NodeRunSucceededEvent, @@ -297,25 +273,11 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: ] ) - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = scenario["handle"] - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - def resume_graph_factory( - initial_result=initial_result, mock_get_repo=mock_get_repo + graph_snapshot: Graph = graph, + state_snapshot: GraphRuntimeState = graph_runtime_state, ) -> tuple[Graph, GraphRuntimeState]: - assert initial_result.graph_runtime_state is not None - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state) + return graph_snapshot, state_snapshot resume_case = WorkflowTestCase( description=f"HumanInput resumes via {scenario['handle']} branch", @@ -359,8 +321,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: for index, 
event in enumerate(resume_events) if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index ] - expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption) - assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index)) + assert pre_indices == list(range(2, 2 + pre_chunk_count)) resume_chunk_indices = [ index diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index d8f229205b..ece69b080b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -1,6 +1,4 @@ -import datetime import time -from unittest.mock import MagicMock from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.message_entities import PromptMessageRole @@ -15,12 +13,11 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.human_input import HumanInputNode +from core.workflow.nodes.human_input.entities import HumanInputNodeData from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, @@ -30,21 +27,15 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode from .test_table_runner import TableTestRunner, WorkflowTestCase -def _build_llm_human_llm_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: +def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} graph_init_params = GraphInitParams( tenant_id="tenant", @@ -57,15 +48,12 @@ def _build_llm_human_llm_graph( call_depth=0, ) - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id," - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) start_config = 
{"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( @@ -104,21 +92,15 @@ def _build_llm_human_llm_graph( human_data = HumanInputNodeData( title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="accept", title="Accept"), - UserAction(id="reject", title="Reject"), - ], + required_variables=["human.input_ready"], + pause_reason="Awaiting human input", ) - human_config = {"id": "human", "data": human_data.model_dump()} human_node = HumanInputNode( id=human_config["id"], config=human_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=form_repository, ) llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") @@ -148,7 +130,7 @@ def _build_llm_human_llm_graph( .add_root(start_node) .add_node(llm_first) .add_node(human_node) - .add_node(llm_second, source_handle="accept") + .add_node(llm_second) .add_node(end_node) .build() ) @@ -185,18 +167,8 @@ def test_human_input_llm_streaming_order_across_pause() -> None: GraphRunPausedEvent, # graph run pauses awaiting resume ] - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - def graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_llm_human_llm_graph(mock_config, mock_create_repo) + return _build_llm_human_llm_graph(mock_config) initial_case = WorkflowTestCase( description="HumanInput pause preserves LLM streaming order", @@ -238,8 +210,6 @@ def test_human_input_llm_streaming_order_across_pause() -> None: expected_resume_sequence: list[type] = [ GraphRunStartedEvent, # resumed graph run begins NodeRunStartedEvent, # human node restarts - # Form Filled should be generated first, then the node execution ends and stream chunk is generated. 
- NodeRunHumanInputFormFilledEvent, NodeRunStreamChunkEvent, # cached llm_initial chunk 1 NodeRunStreamChunkEvent, # cached llm_initial chunk 2 NodeRunStreamChunkEvent, # cached llm_initial final chunk @@ -255,27 +225,12 @@ def test_human_input_llm_streaming_order_across_pause() -> None: GraphRunSucceededEvent, # graph run succeeds after resume ] - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = "accept" - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: - # restruct the graph runtime state - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_llm_human_llm_graph( - mock_config, - mock_get_repo, - resume_runtime_state, - ) + assert graph_runtime_state is not None + assert graph is not None + graph_runtime_state.variable_pool.add(("human", "input_ready"), True) + graph_runtime_state.graph_execution.pause_reason = None + return graph, graph_runtime_state resume_case = WorkflowTestCase( description="HumanInput resume continues LLM streaming order", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py deleted file mode 100644 index a6aab81f6c..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ /dev/null @@ -1,270 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any, Protocol - -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunSucceededEvent, -) -from core.workflow.nodes.base.entities import OutputVariableEntity -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now - - -class 
PauseStateStore(Protocol): - def save(self, runtime_state: GraphRuntimeState) -> None: ... - - def load(self) -> GraphRuntimeState: ... - - -class InMemoryPauseStore: - def __init__(self) -> None: - self._snapshot: str | None = None - - def save(self, runtime_state: GraphRuntimeState) -> None: - self._snapshot = runtime_state.dumps() - - def load(self) -> GraphRuntimeState: - assert self._snapshot is not None - return GraphRuntimeState.from_snapshot(self._snapshot) - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: - self._forms_by_node_id = dict(forms_by_node_id) - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_node_id.get(node_id) - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in resume scenario") - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - - human_a_config = {"id": "human_a", "data": human_data.model_dump()} - human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - ) - - human_b_config = {"id": "human_b", "data": human_data.model_dump()} - human_b = HumanInputNode( - id=human_b_config["id"], - config=human_b_config, - graph_init_params=graph_init_params, - 
graph_runtime_state=runtime_state, - form_repository=repo, - ) - - end_data = EndNodeData( - title="End", - outputs=[ - OutputVariableEntity(variable="res_a", value_selector=["human_a", "__action_id"]), - OutputVariableEntity(variable="res_b", value_selector=["human_b", "__action_id"]), - ], - desc=None, - ) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - builder = ( - Graph.new() - .add_root(start_node) - .add_node(human_a, from_node_id="start") - .add_node(human_b, from_node_id="start") - .add_node(end_node, from_node_id="human_a", source_handle="approve") - ) - return builder.connect(tail="human_b", head="end", source_handle="approve").build() - - -def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[object]: - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - return list(engine.run()) - - -def _form(submitted: bool, action_id: str | None) -> StaticForm: - return StaticForm( - form_id="form", - rendered="rendered", - is_submitted=submitted, - action_id=action_id, - data={}, - status_value=HumanInputFormStatus.SUBMITTED if submitted else HumanInputFormStatus.WAITING, - ) - - -def test_parallel_human_input_join_completes_after_second_resume() -> None: - pause_store: PauseStateStore = InMemoryPauseStore() - - initial_state = _build_runtime_state() - initial_repo = StaticRepo( - { - "human_a": _form(submitted=False, action_id=None), - "human_b": _form(submitted=False, action_id=None), - } - ) - initial_graph = _build_graph(initial_state, initial_repo) - initial_events = _run_graph(initial_graph, initial_state) - - assert isinstance(initial_events[-1], GraphRunPausedEvent) - pause_store.save(initial_state) - - first_resume_state = pause_store.load() - first_resume_repo = StaticRepo( - { - "human_a": _form(submitted=True, action_id="approve"), - "human_b": _form(submitted=False, action_id=None), - } - ) - first_resume_graph = _build_graph(first_resume_state, first_resume_repo) - first_resume_events = _run_graph(first_resume_graph, first_resume_state) - - assert isinstance(first_resume_events[0], GraphRunStartedEvent) - assert first_resume_events[0].reason is WorkflowStartReason.RESUMPTION - assert isinstance(first_resume_events[-1], GraphRunPausedEvent) - pause_store.save(first_resume_state) - - second_resume_state = pause_store.load() - second_resume_repo = StaticRepo( - { - "human_a": _form(submitted=True, action_id="approve"), - "human_b": _form(submitted=True, action_id="approve"), - } - ) - second_resume_graph = _build_graph(second_resume_state, second_resume_repo) - second_resume_events = _run_graph(second_resume_graph, second_resume_state) - - assert isinstance(second_resume_events[0], GraphRunStartedEvent) - assert second_resume_events[0].reason is WorkflowStartReason.RESUMPTION - assert isinstance(second_resume_events[-1], GraphRunSucceededEvent) - assert any(isinstance(event, NodeRunSucceededEvent) and event.node_id == "end" for event in second_resume_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py deleted file mode 100644 index 
62aa56fc57..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ /dev/null @@ -1,333 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: - self._forms_by_node_id = dict(forms_by_node_id) - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_node_id.get(node_id) - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in resume scenario") - - -class DelayedHumanInputNode(HumanInputNode): - def __init__(self, delay_seconds: float, **kwargs: Any) -> None: - super().__init__(**kwargs) - 
self._delay_seconds = delay_seconds - - def _run(self): - if self._delay_seconds > 0: - time.sleep(self._delay_seconds) - yield from super()._run() - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - - human_a_config = {"id": "human_a", "data": human_data.model_dump()} - human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - ) - - human_b_config = {"id": "human_b", "data": human_data.model_dump()} - human_b = DelayedHumanInputNode( - id=human_b_config["id"], - config=human_b_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - delay_seconds=0.2, - ) - - llm_data = LLMNodeData( - title="LLM A", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt A", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_config = {"id": "llm_a", "data": llm_data.model_dump()} - llm_a = MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_a, from_node_id="start") - .add_node(human_b, from_node_id="start") - .add_node(llm_a, from_node_id="human_a", source_handle="approve") - .build() - ) - - -def test_parallel_human_input_pause_preserves_node_finished() -> None: - runtime_state = _build_runtime_state() - - runtime_state.graph_execution.start() - runtime_state.register_paused_node("human_a") - runtime_state.register_paused_node("human_b") - - submitted = StaticForm( - form_id="form-a", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - pending = StaticForm( - form_id="form-b", - rendered="rendered", - is_submitted=False, - action_id=None, - data=None, - status_value=HumanInputFormStatus.WAITING, - ) - repo = StaticRepo({"human_a": submitted, "human_b": pending}) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - 
"llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - - graph = _build_graph(runtime_state, repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - events = list(engine.run()) - - llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) - llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) - human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) - graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) - graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events) - - assert graph_started - assert graph_paused - assert human_b_pause - assert llm_started - assert llm_succeeded - - -def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None: - base_state = _build_runtime_state() - base_state.graph_execution.start() - base_state.register_paused_node("human_a") - base_state.register_paused_node("human_b") - snapshot = base_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted = StaticForm( - form_id="form-a", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - pending = StaticForm( - form_id="form-b", - rendered="rendered", - is_submitted=False, - action_id=None, - data=None, - status_value=HumanInputFormStatus.WAITING, - ) - repo = StaticRepo({"human_a": submitted, "human_b": pending}) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - - graph = _build_graph(resumed_state, repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - events = list(engine.run()) - - start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) - llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) - human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) - graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) - - assert graph_paused - assert human_b_pause - assert llm_started - assert llm_succeeded diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py deleted file mode 100644 index 156cfefcd6..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ /dev/null @@ -1,309 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from core.model_runtime.entities.llm_entities import LLMMode -from 
core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, form: HumanInputFormEntity) -> None: - self._form = form - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - if node_id != "human_pause": - return None - return self._form - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in this test") - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: - graph_config: dict[str, 
object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - llm_a_data = LLMNodeData( - title="LLM A", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt A", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()} - llm_a = MockLLMNode( - id=llm_a_config["id"], - config=llm_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - llm_b_data = LLMNodeData( - title="LLM B", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt B", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()} - llm_b = MockLLMNode( - id=llm_b_config["id"], - config=llm_b_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Pause here", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - human_config = {"id": "human_pause", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - ) - - end_human_data = EndNodeData(title="End Human", outputs=[], desc=None) - end_human_config = {"id": "end_human", "data": end_human_data.model_dump()} - end_human = EndNode( - id=end_human_config["id"], - config=end_human_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(human_node, from_node_id="start") - .add_node(llm_b, from_node_id="llm_a") - .add_node(end_human, from_node_id="human_pause", source_handle="approve") - .build() - ) - - -def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None: - for event in events: - if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: - return event - return None - - -def test_pause_defers_ready_nodes_until_resume() -> None: - runtime_state = _build_runtime_state() - - paused_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=False, - status_value=HumanInputFormStatus.WAITING, - ) - pause_repo = StaticRepo(paused_form) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - 
NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - mock_config.set_node_config( - "llm_b", - NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0), - ) - - graph = _build_graph(runtime_state, pause_repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - paused_events = list(engine.run()) - - assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events) - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events) - assert _get_node_started_event(paused_events, "llm_b") is None - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - resume_repo = StaticRepo(submitted_form) - - resumed_graph = _build_graph(resumed_state, resume_repo, mock_config) - resumed_engine = GraphEngine( - workflow_id="workflow", - graph=resumed_graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - resumed_events = list(resumed_engine.run()) - - start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_b_started = _get_node_started_event(resumed_events, "llm_b") - assert llm_b_started is not None - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py deleted file mode 100644 index 700b3f4b8b..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ /dev/null @@ -1,217 +0,0 @@ -import datetime -import time -from typing import Any -from unittest.mock import MagicMock - -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( - GraphEngineEvent, - GraphRunPausedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from core.workflow.graph_events.graph import GraphRunStartedEvent -from core.workflow.nodes.base.entities import OutputVariableEntity -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from 
core.workflow.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = True - form_entity.selected_action_id = action_id - form_entity.submitted_data = {} - form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - repo.get_form.return_value = form_entity - return repo - - -def _mock_form_repository_without_submission() -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = False - repo.create_form.return_value = form_entity - repo.get_form.return_value = None - return repo - - -def _build_human_input_graph( - runtime_state: GraphRuntimeState, - form_repository: HumanInputFormRepository, -) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", - graph_config=graph_config, - user_id="user", - user_from="account", - invoke_from="service-api", - call_depth=0, - ) - - start_data = StartNodeData(title="start", variables=[]) - start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="human", - form_content="Awaiting human input", - inputs=[], - user_actions=[ - UserAction(id="continue", title="Continue"), - ], - ) - human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - form_repository=form_repository, - ) - - end_data = EndNodeData( - title="end", - outputs=[ - OutputVariableEntity(variable="result", value_selector=["human", "action_id"]), - ], - desc=None, - ) - end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_node) - .add_node(end_node, from_node_id="human", source_handle="continue") - .build() - ) - - -def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]: - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - ) - return list(engine.run()) - - -def _node_successes(events: list[GraphEngineEvent]) -> list[str]: - return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)] - - -def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None: 
-    for event in events:
-        if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
-            return event
-    return None
-
-
-def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
-    segment = variable_pool.get(selector)
-    assert segment is not None
-    return getattr(segment, "value", segment)
-
-
-def test_engine_resume_restores_state_and_completion():
-    # Baseline run without pausing
-    baseline_state = _build_runtime_state()
-    baseline_repo = _mock_form_repository_with_submission(action_id="continue")
-    baseline_graph = _build_human_input_graph(baseline_state, baseline_repo)
-    baseline_events = _run_graph(baseline_graph, baseline_state)
-    assert baseline_events
-    first_paused_event = baseline_events[0]
-    assert isinstance(first_paused_event, GraphRunStartedEvent)
-    assert first_paused_event.reason is WorkflowStartReason.INITIAL
-    assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
-    baseline_success_nodes = _node_successes(baseline_events)
-
-    # Run with pause
-    paused_state = _build_runtime_state()
-    pause_repo = _mock_form_repository_without_submission()
-    paused_graph = _build_human_input_graph(paused_state, pause_repo)
-    paused_events = _run_graph(paused_graph, paused_state)
-    assert paused_events
-    first_paused_event = paused_events[0]
-    assert isinstance(first_paused_event, GraphRunStartedEvent)
-    assert first_paused_event.reason is WorkflowStartReason.INITIAL
-    assert isinstance(paused_events[-1], GraphRunPausedEvent)
-    snapshot = paused_state.dumps()
-
-    # Resume from snapshot
-    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
-    resume_repo = _mock_form_repository_with_submission(action_id="continue")
-    resumed_graph = _build_human_input_graph(resumed_state, resume_repo)
-    resumed_events = _run_graph(resumed_graph, resumed_state)
-    assert resumed_events
-    first_resumed_event = resumed_events[0]
-    assert isinstance(first_resumed_event, GraphRunStartedEvent)
-    assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION
-    assert isinstance(resumed_events[-1], GraphRunSucceededEvent)
-
-    combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
-    assert combined_success_nodes == baseline_success_nodes
-
-    paused_human_started = _node_start_event(paused_events, "human")
-    resumed_human_started = _node_start_event(resumed_events, "human")
-    assert paused_human_started is not None
-    assert resumed_human_started is not None
-    assert paused_human_started.id == resumed_human_started.id
-
-    assert baseline_state.outputs == resumed_state.outputs
-    assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
-        resumed_state.variable_pool, ("human", "__action_id")
-    )
-    assert baseline_state.graph_execution.completed
-    assert resumed_state.graph_execution.completed
diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py
index 21a642c2f8..488b47761b 100644
--- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py
@@ -7,7 +7,6 @@ from core.workflow.nodes.base.node import Node
 
 # Ensures that all node classes are imported.
 from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
 
-# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed.
 _ = NODE_TYPE_CLASSES_MAPPING
 
@@ -46,9 +45,7 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
         assert isinstance(cls.node_type, NodeType)
         assert isinstance(node_version, str)
         node_type_and_version = (node_type, node_version)
-        assert node_type_and_version not in type_version_set, (
-            f"Duplicate node type and version for class: {cls=} {node_type_and_version=}"
-        )
+        assert node_type_and_version not in type_version_set
         type_version_set.add(node_type_and_version)
 
 
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py b/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py
deleted file mode 100644
index 20807e9ef9..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Unit tests for human input node
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
deleted file mode 100644
index ca4a887d20..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients
-from core.workflow.runtime import VariablePool
-
-
-def test_render_body_template_replaces_variable_values():
-    config = EmailDeliveryConfig(
-        recipients=EmailRecipients(),
-        subject="Subject",
-        body="Hello {{#node1.value#}} {{#url#}}",
-    )
-    variable_pool = VariablePool()
-    variable_pool.add(["node1", "value"], "World")
-
-    result = config.render_body_template(body=config.body, url="https://example.com", variable_pool=variable_pool)
-
-    assert result == "Hello World https://example.com"
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
deleted file mode 100644
index bfe7b03c13..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
+++ /dev/null
@@ -1,597 +0,0 @@
-"""
-Unit tests for human input node entities.
-"""
-
-from types import SimpleNamespace
-from unittest.mock import MagicMock
-
-import pytest
-from pydantic import ValidationError
-
-from core.workflow.entities import GraphInitParams
-from core.workflow.node_events import PauseRequestedEvent
-from core.workflow.node_events.node import StreamCompletedEvent
-from core.workflow.nodes.human_input.entities import (
-    EmailDeliveryConfig,
-    EmailDeliveryMethod,
-    EmailRecipients,
-    ExternalRecipient,
-    FormInput,
-    FormInputDefault,
-    HumanInputNodeData,
-    MemberRecipient,
-    UserAction,
-    WebAppDeliveryMethod,
-    _WebAppDeliveryConfig,
-)
-from core.workflow.nodes.human_input.enums import (
-    ButtonStyle,
-    DeliveryMethodType,
-    EmailRecipientType,
-    FormInputType,
-    PlaceholderType,
-    TimeoutUnit,
-)
-from core.workflow.nodes.human_input.human_input_node import HumanInputNode
-from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
-from core.workflow.runtime import GraphRuntimeState, VariablePool
-from core.workflow.system_variable import SystemVariable
-from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository
-
-
-class TestDeliveryMethod:
-    """Test DeliveryMethod entity."""
-
-    def test_webapp_delivery_method(self):
-        """Test webapp delivery method creation."""
-        delivery_method = WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())
-
-        assert delivery_method.type == DeliveryMethodType.WEBAPP
-        assert delivery_method.enabled is True
-        assert isinstance(delivery_method.config, _WebAppDeliveryConfig)
-
-    def test_email_delivery_method(self):
-        """Test email delivery method creation."""
-        recipients = EmailRecipients(
-            whole_workspace=False,
-            items=[
-                MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"),
-                ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"),
-            ],
-        )
-
-        config = EmailDeliveryConfig(
-            recipients=recipients, subject="Test Subject", body="Test body with {{#url#}} placeholder"
-        )
-
-        delivery_method = EmailDeliveryMethod(enabled=True, config=config)
-
-        assert delivery_method.type == DeliveryMethodType.EMAIL
-        assert delivery_method.enabled is True
-        assert isinstance(delivery_method.config, EmailDeliveryConfig)
-        assert delivery_method.config.subject == "Test Subject"
-        assert len(delivery_method.config.recipients.items) == 2
-
-
-class TestFormInput:
-    """Test FormInput entity."""
-
-    def test_text_input_with_constant_default(self):
-        """Test text input with constant default value."""
-        default = FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter your response here...")
-
-        form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default)
-
-        assert form_input.type == FormInputType.TEXT_INPUT
-        assert form_input.output_variable_name == "user_input"
-        assert form_input.default.type == PlaceholderType.CONSTANT
-        assert form_input.default.value == "Enter your response here..."
-
-    def test_text_input_with_variable_default(self):
-        """Test text input with variable default value."""
-        default = FormInputDefault(type=PlaceholderType.VARIABLE, selector=["node_123", "output_var"])
-
-        form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default)
-
-        assert form_input.default.type == PlaceholderType.VARIABLE
-        assert form_input.default.selector == ["node_123", "output_var"]
-
-    def test_form_input_without_default(self):
-        """Test form input without default value."""
-        form_input = FormInput(type=FormInputType.PARAGRAPH, output_variable_name="description")
-
-        assert form_input.type == FormInputType.PARAGRAPH
-        assert form_input.output_variable_name == "description"
-        assert form_input.default is None
-
-
-class TestUserAction:
-    """Test UserAction entity."""
-
-    def test_user_action_creation(self):
-        """Test user action creation."""
-        action = UserAction(id="approve", title="Approve", button_style=ButtonStyle.PRIMARY)
-
-        assert action.id == "approve"
-        assert action.title == "Approve"
-        assert action.button_style == ButtonStyle.PRIMARY
-
-    def test_user_action_default_button_style(self):
-        """Test user action with default button style."""
-        action = UserAction(id="cancel", title="Cancel")
-
-        assert action.button_style == ButtonStyle.DEFAULT
-
-    def test_user_action_length_boundaries(self):
-        """Test user action id and title length boundaries."""
-        action = UserAction(id="a" * 20, title="b" * 20)
-
-        assert action.id == "a" * 20
-        assert action.title == "b" * 20
-
-    @pytest.mark.parametrize(
-        ("field_name", "value"),
-        [
-            ("id", "a" * 21),
-            ("title", "b" * 21),
-        ],
-    )
-    def test_user_action_length_limits(self, field_name: str, value: str):
-        """User action fields should enforce max length."""
-        data = {"id": "approve", "title": "Approve"}
-        data[field_name] = value
-
-        with pytest.raises(ValidationError) as exc_info:
-            UserAction(**data)
-
-        errors = exc_info.value.errors()
-        assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors)
-
-
-class TestHumanInputNodeData:
-    """Test HumanInputNodeData entity."""
-
-    def test_valid_node_data_creation(self):
-        """Test creating valid human input node data."""
-        delivery_methods = [WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())]
-
-        inputs = [
-            FormInput(
-                type=FormInputType.TEXT_INPUT,
-                output_variable_name="content",
-                default=FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter content..."),
-            )
-        ]
-
-        user_actions = [UserAction(id="submit", title="Submit", button_style=ButtonStyle.PRIMARY)]
-
-        node_data = HumanInputNodeData(
-            title="Human Input Test",
-            desc="Test node description",
-            delivery_methods=delivery_methods,
-            form_content="# Test Form\n\nPlease provide input:\n\n{{#$output.content#}}",
-            inputs=inputs,
-            user_actions=user_actions,
-            timeout=24,
-            timeout_unit=TimeoutUnit.HOUR,
-        )
-
-        assert node_data.title == "Human Input Test"
-        assert node_data.desc == "Test node description"
-        assert len(node_data.delivery_methods) == 1
-        assert node_data.form_content.startswith("# Test Form")
-        assert len(node_data.inputs) == 1
-        assert len(node_data.user_actions) == 1
-        assert node_data.timeout == 24
-        assert node_data.timeout_unit == TimeoutUnit.HOUR
-
-    def test_node_data_with_multiple_delivery_methods(self):
-        """Test node data with multiple delivery methods."""
-        delivery_methods = [
-            WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()),
-            EmailDeliveryMethod(
-                enabled=False,  # Disabled method should be fine
-                config=EmailDeliveryConfig(
-                    subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True)
-                ),
-            ),
-        ]
-
-        node_data = HumanInputNodeData(
-            title="Test Node", delivery_methods=delivery_methods, timeout=1, timeout_unit=TimeoutUnit.DAY
-        )
-
-        assert len(node_data.delivery_methods) == 2
-        assert node_data.timeout == 1
-        assert node_data.timeout_unit == TimeoutUnit.DAY
-
-    def test_node_data_defaults(self):
-        """Test node data with default values."""
-        node_data = HumanInputNodeData(title="Test Node")
-
-        assert node_data.title == "Test Node"
-        assert node_data.desc is None
-        assert node_data.delivery_methods == []
-        assert node_data.form_content == ""
-        assert node_data.inputs == []
-        assert node_data.user_actions == []
-        assert node_data.timeout == 36
-        assert node_data.timeout_unit == TimeoutUnit.HOUR
-
-    def test_duplicate_input_output_variable_name_raises_validation_error(self):
-        """Duplicate form input output_variable_name should raise validation error."""
-        duplicate_inputs = [
-            FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
-            FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
-        ]
-
-        with pytest.raises(ValidationError, match="duplicated output_variable_name 'content'"):
-            HumanInputNodeData(title="Test Node", inputs=duplicate_inputs)
-
-    def test_duplicate_user_action_ids_raise_validation_error(self):
-        """Duplicate user action ids should raise validation error."""
-        duplicate_actions = [
-            UserAction(id="submit", title="Submit"),
-            UserAction(id="submit", title="Submit Again"),
-        ]
-
-        with pytest.raises(ValidationError, match="duplicated user action id 'submit'"):
-            HumanInputNodeData(title="Test Node", user_actions=duplicate_actions)
-
-    def test_extract_outputs_field_names(self):
-        content = r"""This is titile {{#start.title#}}
-
-        A content is required:
-
-        {{#$output.content#}}
-
-        A ending is required:
-
-        {{#$output.ending#}}
-        """
-
-        node_data = HumanInputNodeData(title="Human Input", form_content=content)
-        field_names = node_data.outputs_field_names()
-        assert field_names == ["content", "ending"]
-
-
-class TestRecipients:
-    """Test email recipient entities."""
-
-    def test_member_recipient(self):
-        """Test member recipient creation."""
-        recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")
-
-        assert recipient.type == EmailRecipientType.MEMBER
-        assert recipient.user_id == "user-123"
-
-    def test_external_recipient(self):
-        """Test external recipient creation."""
-        recipient = ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com")
-
-        assert recipient.type == EmailRecipientType.EXTERNAL
-        assert recipient.email == "test@example.com"
-
-    def test_email_recipients_whole_workspace(self):
-        """Test email recipients with whole workspace enabled."""
-        recipients = EmailRecipients(
-            whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")]
-        )
-
-        assert recipients.whole_workspace is True
-        assert len(recipients.items) == 1  # Items are preserved even when whole_workspace is True
-
-    def test_email_recipients_specific_users(self):
-        """Test email recipients with specific users."""
-        recipients = EmailRecipients(
-            whole_workspace=False,
-            items=[
-                MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"),
-                ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"),
-            ],
-        )
-
-        assert recipients.whole_workspace is False
-        assert len(recipients.items) == 2
-        assert recipients.items[0].user_id == "user-123"
-        assert recipients.items[1].email == "external@example.com"
-
-
-class TestHumanInputNodeVariableResolution:
-    """Tests for resolving variable-based defaults in HumanInputNode."""
-
-    def test_resolves_variable_defaults(self):
-        variable_pool = VariablePool(
-            system_variables=SystemVariable(
-                user_id="user",
-                app_id="app",
-                workflow_id="workflow",
-                workflow_execution_id="exec-1",
-            ),
-            user_inputs={},
-            conversation_variables=[],
-        )
-        variable_pool.add(("start", "name"), "Jane Doe")
-        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
-        graph_init_params = GraphInitParams(
-            tenant_id="tenant",
-            app_id="app",
-            workflow_id="workflow",
-            graph_config={"nodes": [], "edges": []},
-            user_id="user",
-            user_from="account",
-            invoke_from="debugger",
-            call_depth=0,
-        )
-
-        node_data = HumanInputNodeData(
-            title="Human Input",
-            form_content="Provide your name",
-            inputs=[
-                FormInput(
-                    type=FormInputType.TEXT_INPUT,
-                    output_variable_name="user_name",
-                    default=FormInputDefault(type=PlaceholderType.VARIABLE, selector=["start", "name"]),
-                ),
-                FormInput(
-                    type=FormInputType.TEXT_INPUT,
-                    output_variable_name="user_email",
-                    default=FormInputDefault(type=PlaceholderType.CONSTANT, value="foo@example.com"),
-                ),
-            ],
-            user_actions=[UserAction(id="submit", title="Submit")],
-        )
-        config = {"id": "human", "data": node_data.model_dump()}
-
-        mock_repo = MagicMock(spec=HumanInputFormRepository)
-        mock_repo.get_form.return_value = None
-        mock_repo.create_form.return_value = SimpleNamespace(
-            id="form-1",
-            rendered_content="Provide your name",
-            web_app_token="token",
-            recipients=[],
-            submitted=False,
-        )
-
-        node = HumanInputNode(
-            id=config["id"],
-            config=config,
-            graph_init_params=graph_init_params,
-            graph_runtime_state=runtime_state,
-            form_repository=mock_repo,
-        )
-
-        run_result = node._run()
-        pause_event = next(run_result)
-
-        assert isinstance(pause_event, PauseRequestedEvent)
-        expected_values = {"user_name": "Jane Doe"}
-        assert pause_event.reason.resolved_default_values == expected_values
-
-        params = mock_repo.create_form.call_args.args[0]
-        assert params.resolved_default_values == expected_values
-
-    def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self):
-        variable_pool = VariablePool(
-            system_variables=SystemVariable(
-                user_id="user",
-                app_id="app",
-                workflow_id="workflow",
-                workflow_execution_id="exec-2",
-            ),
-            user_inputs={},
-            conversation_variables=[],
-        )
-        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
-        graph_init_params = GraphInitParams(
-            tenant_id="tenant",
-            app_id="app",
-            workflow_id="workflow",
-            graph_config={"nodes": [], "edges": []},
-            user_id="user",
-            user_from="account",
-            invoke_from="debugger",
-            call_depth=0,
-        )
-
-        node_data = HumanInputNodeData(
-            title="Human Input",
-            form_content="Provide your name",
-            inputs=[],
-            user_actions=[UserAction(id="submit", title="Submit")],
-        )
-        config = {"id": "human", "data": node_data.model_dump()}
-
-        mock_repo = MagicMock(spec=HumanInputFormRepository)
-        mock_repo.get_form.return_value = None
-        mock_repo.create_form.return_value = SimpleNamespace(
-            id="form-2",
-            rendered_content="Provide your name",
-            web_app_token="console-token",
-            recipients=[SimpleNamespace(token="recipient-token")],
-            submitted=False,
-        )
-
-        node = HumanInputNode(
-            id=config["id"],
-            config=config,
-            graph_init_params=graph_init_params,
-            graph_runtime_state=runtime_state,
-            form_repository=mock_repo,
-        )
-
-        run_result = node._run()
-        pause_event = next(run_result)
-
-        assert isinstance(pause_event, PauseRequestedEvent)
-        assert pause_event.reason.form_token == "console-token"
-
-    def test_debugger_debug_mode_overrides_email_recipients(self):
-        variable_pool = VariablePool(
-            system_variables=SystemVariable(
-                user_id="user-123",
-                app_id="app",
-                workflow_id="workflow",
-                workflow_execution_id="exec-3",
-            ),
-            user_inputs={},
-            conversation_variables=[],
-        )
-        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
-        graph_init_params = GraphInitParams(
-            tenant_id="tenant",
-            app_id="app",
-            workflow_id="workflow",
-            graph_config={"nodes": [], "edges": []},
-            user_id="user-123",
-            user_from="account",
-            invoke_from="debugger",
-            call_depth=0,
-        )
-
-        node_data = HumanInputNodeData(
-            title="Human Input",
-            form_content="Provide your name",
-            inputs=[],
-            user_actions=[UserAction(id="submit", title="Submit")],
-            delivery_methods=[
-                EmailDeliveryMethod(
-                    enabled=True,
-                    config=EmailDeliveryConfig(
-                        recipients=EmailRecipients(
-                            whole_workspace=False,
-                            items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")],
-                        ),
-                        subject="Subject",
-                        body="Body",
-                        debug_mode=True,
-                    ),
-                )
-            ],
-        )
-        config = {"id": "human", "data": node_data.model_dump()}
-
-        mock_repo = MagicMock(spec=HumanInputFormRepository)
-        mock_repo.get_form.return_value = None
-        mock_repo.create_form.return_value = SimpleNamespace(
-            id="form-3",
-            rendered_content="Provide your name",
-            web_app_token="token",
-            recipients=[],
-            submitted=False,
-        )
-
-        node = HumanInputNode(
-            id=config["id"],
-            config=config,
-            graph_init_params=graph_init_params,
-            graph_runtime_state=runtime_state,
-            form_repository=mock_repo,
-        )
-
-        run_result = node._run()
-        pause_event = next(run_result)
-        assert isinstance(pause_event, PauseRequestedEvent)
-
-        params = mock_repo.create_form.call_args.args[0]
-        assert len(params.delivery_methods) == 1
-        method = params.delivery_methods[0]
-        assert isinstance(method, EmailDeliveryMethod)
-        assert method.config.debug_mode is True
-        assert method.config.recipients.whole_workspace is False
-        assert len(method.config.recipients.items) == 1
-        recipient = method.config.recipients.items[0]
-        assert isinstance(recipient, MemberRecipient)
-        assert recipient.user_id == "user-123"
-
-
-class TestValidation:
-    """Test validation scenarios."""
-
-    def test_invalid_form_input_type(self):
-        """Test validation with invalid form input type."""
-        with pytest.raises(ValidationError):
-            FormInput(
-                type="invalid-type",  # Invalid type
-                output_variable_name="test",
-            )
-
-    def test_invalid_button_style(self):
-        """Test validation with invalid button style."""
-        with pytest.raises(ValidationError):
-            UserAction(
-                id="test",
-                title="Test",
-                button_style="invalid-style",  # Invalid style
-            )
-
-    def test_invalid_timeout_unit(self):
-        """Test validation with invalid timeout unit."""
-        with pytest.raises(ValidationError):
-            HumanInputNodeData(
-                title="Test",
-                timeout_unit="invalid-unit",  # Invalid unit
-            )
-
-
-class TestHumanInputNodeRenderedContent:
-    """Tests for rendering submitted content."""
-
-    def test_replaces_outputs_placeholders_after_submission(self):
-        variable_pool = VariablePool(
-            system_variables=SystemVariable(
-                user_id="user",
-                app_id="app",
-                workflow_id="workflow",
-                workflow_execution_id="exec-1",
-            ),
-            user_inputs={},
-            conversation_variables=[],
-        )
-        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
-        graph_init_params = GraphInitParams(
-            tenant_id="tenant",
-            app_id="app",
-            workflow_id="workflow",
-            graph_config={"nodes": [], "edges": []},
-            user_id="user",
-            user_from="account",
-            invoke_from="debugger",
-            call_depth=0,
-        )
-
-        node_data = HumanInputNodeData(
-            title="Human Input",
-            form_content="Name: {{#$output.name#}}",
-            inputs=[
-                FormInput(
-                    type=FormInputType.TEXT_INPUT,
-                    output_variable_name="name",
-                )
-            ],
-            user_actions=[UserAction(id="approve", title="Approve")],
-        )
-        config = {"id": "human", "data": node_data.model_dump()}
-
-        form_repository = InMemoryHumanInputFormRepository()
-        node = HumanInputNode(
-            id=config["id"],
-            config=config,
-            graph_init_params=graph_init_params,
-            graph_runtime_state=runtime_state,
-            form_repository=form_repository,
-        )
-
-        pause_gen = node._run()
-        pause_event = next(pause_gen)
-        assert isinstance(pause_event, PauseRequestedEvent)
-        with pytest.raises(StopIteration):
-            next(pause_gen)
-
-        form_repository.set_submission(action_id="approve", form_data={"name": "Alice"})
-
-        events = list(node._run())
-        last_event = events[-1]
-        assert isinstance(last_event, StreamCompletedEvent)
-        node_run_result = last_event.node_run_result
-        assert node_run_result.outputs["__rendered_content"] == "Name: Alice"
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
deleted file mode 100644
index a19ee4dee3..0000000000
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import datetime
-from types import SimpleNamespace
-
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.entities.graph_init_params import GraphInitParams
-from core.workflow.enums import NodeType
-from core.workflow.graph_events import (
-    NodeRunHumanInputFormFilledEvent,
-    NodeRunHumanInputFormTimeoutEvent,
-    NodeRunStartedEvent,
-)
-from core.workflow.nodes.human_input.enums import HumanInputFormStatus
-from core.workflow.nodes.human_input.human_input_node import HumanInputNode
-from core.workflow.runtime import GraphRuntimeState, VariablePool
-from core.workflow.system_variable import SystemVariable
-from libs.datetime_utils import naive_utc_now
-from models.enums import UserFrom
-
-
-class _FakeFormRepository:
-    def __init__(self, form):
-        self._form = form
-
-    def get_form(self, *_args, **_kwargs):
-        return self._form
-
-
-def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
-    system_variables = SystemVariable.default()
-    graph_runtime_state = GraphRuntimeState(
-        variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
-        start_at=0.0,
-    )
-    graph_init_params = GraphInitParams(
-        tenant_id="tenant",
-        app_id="app",
-        workflow_id="workflow",
-        graph_config={"nodes": [], "edges": []},
-        user_id="user",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.SERVICE_API,
-        call_depth=0,
-    )
-
-    config = {
-        "id": "node-1",
-        "type": NodeType.HUMAN_INPUT.value,
-        "data": {
-            "title": "Human Input",
-            "form_content": form_content,
-            "inputs": [
-                {
-                    "type": "text_input",
-                    "output_variable_name": "name",
-                    "default": {"type": "constant", "value": ""},
-                }
-            ],
-            "user_actions": [
-                {
-                    "id": "Accept",
-                    "title": "Approve",
-                    "button_style": "default",
-                }
-            ],
-        },
-    }
-
-    fake_form = SimpleNamespace(
-        id="form-1",
-        rendered_content=form_content,
-        submitted=True,
-        selected_action_id="Accept",
-        submitted_data={"name": "Alice"},
-        status=HumanInputFormStatus.SUBMITTED,
-        expiration_time=naive_utc_now() + datetime.timedelta(days=1),
-    )
-
-    repo = _FakeFormRepository(fake_form)
-    return HumanInputNode(
-        id="node-1",
-        config=config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=graph_runtime_state,
-        form_repository=repo,
-    )
-
-
-def _build_timeout_node() -> HumanInputNode:
-    system_variables = SystemVariable.default()
-    graph_runtime_state = GraphRuntimeState(
-        variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
-        start_at=0.0,
-    )
-    graph_init_params = GraphInitParams(
-        tenant_id="tenant",
-        app_id="app",
-        workflow_id="workflow",
-        graph_config={"nodes": [], "edges": []},
-        user_id="user",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.SERVICE_API,
-        call_depth=0,
-    )
-
-    config = {
-        "id": "node-1",
-        "type": NodeType.HUMAN_INPUT.value,
-        "data": {
-            "title": "Human Input",
-            "form_content": "Please enter your name:\n\n{{#$output.name#}}",
-            "inputs": [
-                {
-                    "type": "text_input",
-                    "output_variable_name": "name",
-                    "default": {"type": "constant", "value": ""},
-                }
-            ],
-            "user_actions": [
-                {
-                    "id": "Accept",
-                    "title": "Approve",
-                    "button_style": "default",
-                }
-            ],
-        },
-    }
-
-    fake_form = SimpleNamespace(
-        id="form-1",
-        rendered_content="content",
-        submitted=False,
-        selected_action_id=None,
-        submitted_data=None,
-        status=HumanInputFormStatus.TIMEOUT,
-        expiration_time=naive_utc_now() - datetime.timedelta(minutes=1),
-    )
-
-    repo = _FakeFormRepository(fake_form)
-    return HumanInputNode(
-        id="node-1",
-        config=config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=graph_runtime_state,
-        form_repository=repo,
-    )
-
-
-def test_human_input_node_emits_form_filled_event_before_succeeded():
-    node = _build_node()
-
-    events = list(node.run())
-
-    assert isinstance(events[0], NodeRunStartedEvent)
-    assert isinstance(events[1], NodeRunHumanInputFormFilledEvent)
-
-    filled_event = events[1]
-    assert filled_event.node_title == "Human Input"
-    assert filled_event.rendered_content.endswith("Alice")
-    assert filled_event.action_id == "Accept"
-    assert filled_event.action_text == "Approve"
-
-
-def test_human_input_node_emits_timeout_event_before_succeeded():
-    node = _build_timeout_node()
-
-    events = list(node.run())
-
-    assert isinstance(events[0], NodeRunStartedEvent)
-    assert isinstance(events[1], NodeRunHumanInputFormTimeoutEvent)
-
-    timeout_event = events[1]
-    assert timeout_event.node_title == "Human Input"
diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool_conver.py b/api/tests/unit_tests/core/workflow/test_variable_pool_conver.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py
index 38477409bb..d3a4d69f07 100644
--- a/api/tests/unit_tests/extensions/test_celery_ssl.py
+++ b/api/tests/unit_tests/extensions/test_celery_ssl.py
@@ -104,7 +104,6 @@ class TestCelerySSLConfiguration:
     def test_celery_init_applies_ssl_to_broker_and_backend(self):
         """Test that SSL options are applied to both broker and backend when using Redis."""
         mock_config = MagicMock()
-        mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1
         mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
         mock_config.CELERY_BACKEND = "redis"
         mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0"
diff --git a/api/tests/unit_tests/extensions/test_pubsub_channel.py b/api/tests/unit_tests/extensions/test_pubsub_channel.py
deleted file mode 100644
index a5b41a7266..0000000000
--- a/api/tests/unit_tests/extensions/test_pubsub_channel.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from configs import dify_config
-from extensions import ext_redis
-from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
-from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
-
-
-def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch):
-    monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")
-
-    channel = ext_redis.get_pubsub_broadcast_channel()
-
-    assert isinstance(channel, RedisBroadcastChannel)
-
-
-def test_get_pubsub_broadcast_channel_sharded(monkeypatch):
-    monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded")
-
-    channel = ext_redis.get_pubsub_broadcast_channel()
-
-    assert isinstance(channel, ShardedRedisBroadcastChannel)
diff --git a/api/tests/unit_tests/libs/_human_input/__init__.py b/api/tests/unit_tests/libs/_human_input/__init__.py
deleted file mode 100644
index 66714e72f8..0000000000
--- a/api/tests/unit_tests/libs/_human_input/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Treat this directory as a package so support modules can be imported relatively.
diff --git a/api/tests/unit_tests/libs/_human_input/support.py b/api/tests/unit_tests/libs/_human_input/support.py
deleted file mode 100644
index bd86c13a2c..0000000000
--- a/api/tests/unit_tests/libs/_human_input/support.py
+++ /dev/null
@@ -1,249 +0,0 @@
-from __future__ import annotations
-
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta
-from typing import Any
-
-from core.workflow.nodes.human_input.entities import FormInput
-from core.workflow.nodes.human_input.enums import TimeoutUnit
-
-
-# Exceptions
-class HumanInputError(Exception):
-    error_code: str = "unknown"
-
-    def __init__(self, message: str = "", error_code: str | None = None):
-        super().__init__(message)
-        self.message = message or self.__class__.__name__
-        if error_code:
-            self.error_code = error_code
-
-
-class FormNotFoundError(HumanInputError):
-    error_code = "form_not_found"
-
-
-class FormExpiredError(HumanInputError):
-    error_code = "human_input_form_expired"
-
-
-class FormAlreadySubmittedError(HumanInputError):
-    error_code = "human_input_form_submitted"
-
-
-class InvalidFormDataError(HumanInputError):
-    error_code = "invalid_form_data"
-
-
-# Models
-@dataclass
-class HumanInputForm:
-    form_id: str
-    workflow_run_id: str
-    node_id: str
-    tenant_id: str
-    app_id: str | None
-    form_content: str
-    inputs: list[FormInput]
-    user_actions: list[dict[str, Any]]
-    timeout: int
-    timeout_unit: TimeoutUnit
-    form_token: str | None = None
-    created_at: datetime = field(default_factory=datetime.utcnow)
-    expires_at: datetime | None = None
-    submitted_at: datetime | None = None
-    submitted_data: dict[str, Any] | None = None
-    submitted_action: str | None = None
-
-    def __post_init__(self) -> None:
-        if self.expires_at is None:
-            self.calculate_expiration()
-
-    @property
-    def is_expired(self) -> bool:
-        return self.expires_at is not None and datetime.utcnow() > self.expires_at
-
-    @property
-    def is_submitted(self) -> bool:
-        return self.submitted_at is not None
-
-    def mark_submitted(self, inputs: dict[str, Any], action: str) -> None:
-        self.submitted_data = inputs
-        self.submitted_action = action
-        self.submitted_at = datetime.utcnow()
-
-    def submit(self, inputs: dict[str, Any], action: str) -> None:
-        self.mark_submitted(inputs, action)
-
-    def calculate_expiration(self) -> None:
-        start = self.created_at
-        if self.timeout_unit == TimeoutUnit.HOUR:
-            self.expires_at = start + timedelta(hours=self.timeout)
-        elif self.timeout_unit == TimeoutUnit.DAY:
-            self.expires_at = start + timedelta(days=self.timeout)
-        else:
-            raise ValueError(f"Unsupported timeout unit {self.timeout_unit}")
-
-    def to_response_dict(self, *, include_site_info: bool) -> dict[str, Any]:
-        inputs_response = [
-            {
-                "type": form_input.type.name.lower().replace("_", "-"),
-                "output_variable_name": form_input.output_variable_name,
-            }
-            for form_input in self.inputs
-        ]
-        response = {
-            "form_content": self.form_content,
-            "inputs": inputs_response,
-            "user_actions": self.user_actions,
-        }
-        if include_site_info:
-            response["site"] = {"app_id": self.app_id, "title": "Workflow Form"}
-        return response
-
-
-@dataclass
-class FormSubmissionData:
-    form_id: str
-    inputs: dict[str, Any]
-    action: str
-    submitted_at: datetime = field(default_factory=datetime.utcnow)
-
-    @classmethod
-    def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData:  # type: ignore
-        return cls(form_id=form_id, inputs=request.inputs, action=request.action)
-
-
-@dataclass
-class FormSubmissionRequest:
-    inputs: dict[str, Any]
-    action: str
-
-
-# Repository
-class InMemoryFormRepository:
-    """
-    Simple in-memory repository used by unit tests.
-    """
-
-    def __init__(self):
-        self._forms: dict[str, HumanInputForm] = {}
-
-    @property
-    def forms(self) -> dict[str, HumanInputForm]:
-        return self._forms
-
-    def save(self, form: HumanInputForm) -> None:
-        self._forms[form.form_id] = form
-
-    def get_by_id(self, form_id: str) -> HumanInputForm | None:
-        return self._forms.get(form_id)
-
-    def get_by_token(self, token: str) -> HumanInputForm | None:
-        for form in self._forms.values():
-            if form.form_token == token:
-                return form
-        return None
-
-    def delete(self, form_id: str) -> None:
-        self._forms.pop(form_id, None)
-
-
-# Service
-class FormService:
-    """Service layer for managing human input forms in tests."""
-
-    def __init__(self, repository: InMemoryFormRepository):
-        self.repository = repository
-
-    def create_form(
-        self,
-        *,
-        form_id: str,
-        workflow_run_id: str,
-        node_id: str,
-        tenant_id: str,
-        app_id: str | None,
-        form_content: str,
-        inputs,
-        user_actions,
-        timeout: int,
-        timeout_unit: TimeoutUnit,
-        form_token: str | None = None,
-    ) -> HumanInputForm:
-        form = HumanInputForm(
-            form_id=form_id,
-            workflow_run_id=workflow_run_id,
-            node_id=node_id,
-            tenant_id=tenant_id,
-            app_id=app_id,
-            form_content=form_content,
-            inputs=list(inputs),
-            user_actions=[{"id": action.id, "title": action.title} for action in user_actions],
-            timeout=timeout,
-            timeout_unit=timeout_unit,
-            form_token=form_token,
-        )
-        form.calculate_expiration()
-        self.repository.save(form)
-        return form
-
-    def get_form_by_id(self, form_id: str) -> HumanInputForm:
-        form = self.repository.get_by_id(form_id)
-        if form is None:
-            raise FormNotFoundError()
-        return form
-
-    def get_form_by_token(self, token: str) -> HumanInputForm:
-        form = self.repository.get_by_token(token)
-        if form is None:
-            raise FormNotFoundError()
-        return form
-
-    def get_form_definition(self, form_id: str, *, is_token: bool) -> dict:
-        form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
-        if form.is_expired:
-            raise FormExpiredError()
-        if form.is_submitted:
-            raise FormAlreadySubmittedError()
-
-        definition = {
-            "form_content": form.form_content,
-            "inputs": form.inputs,
-            "user_actions": form.user_actions,
-        }
-        if is_token:
-            definition["site"] = {"title": "Workflow Form"}
-        return definition
-
-    def submit_form(self, form_id: str, submission_data: FormSubmissionData, *, is_token: bool) -> None:
-        form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
-        if form.is_expired:
-            raise FormExpiredError()
-        if form.is_submitted:
-            raise FormAlreadySubmittedError()
-
-        self._validate_submission(form=form, submission_data=submission_data)
-        form.mark_submitted(inputs=submission_data.inputs, action=submission_data.action)
-        self.repository.save(form)
-
-    def cleanup_expired_forms(self) -> int:
-        expired_ids = [form_id for form_id, form in list(self.repository.forms.items()) if form.is_expired]
-        for form_id in expired_ids:
-            self.repository.delete(form_id)
-        return len(expired_ids)
-
-    def _validate_submission(self, form: HumanInputForm, submission_data: FormSubmissionData) -> None:
-        defined_actions = {action["id"] for action in form.user_actions}
-        if submission_data.action not in defined_actions:
-            raise InvalidFormDataError(f"Invalid action: {submission_data.action}")
-
-        missing_inputs = []
-        for form_input in form.inputs:
-            if form_input.output_variable_name not in submission_data.inputs:
-                missing_inputs.append(form_input.output_variable_name)
-
-        if missing_inputs:
-            raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")
-
-        # Extra inputs are allowed; no further validation required.
diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py
deleted file mode 100644
index 15e7d41e85..0000000000
--- a/api/tests/unit_tests/libs/_human_input/test_form_service.py
+++ /dev/null
@@ -1,326 +0,0 @@
-"""
-Unit tests for FormService.
-"""
-
-from datetime import datetime, timedelta
-
-import pytest
-
-from core.workflow.nodes.human_input.entities import (
-    FormInput,
-    UserAction,
-)
-from core.workflow.nodes.human_input.enums import (
-    FormInputType,
-    TimeoutUnit,
-)
-from libs.datetime_utils import naive_utc_now
-
-from .support import (
-    FormAlreadySubmittedError,
-    FormExpiredError,
-    FormNotFoundError,
-    FormService,
-    FormSubmissionData,
-    InMemoryFormRepository,
-    InvalidFormDataError,
-)
-
-
-class TestFormService:
-    """Test FormService functionality."""
-
-    @pytest.fixture
-    def repository(self):
-        """Create in-memory repository for testing."""
-        return InMemoryFormRepository()
-
-    @pytest.fixture
-    def form_service(self, repository):
-        """Create FormService with in-memory repository."""
-        return FormService(repository)
-
-    @pytest.fixture
-    def sample_form_data(self):
-        """Create sample form data."""
-        return {
-            "form_id": "form-123",
-            "workflow_run_id": "run-456",
-            "node_id": "node-789",
-            "tenant_id": "tenant-abc",
-            "app_id": "app-def",
-            "form_content": "# Test Form\n\nInput: {{#$output.input#}}",
-            "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)],
-            "user_actions": [UserAction(id="submit", title="Submit")],
-            "timeout": 1,
-            "timeout_unit": TimeoutUnit.HOUR,
-            "form_token": "token-xyz",
-        }
-
-    def test_create_form(self, form_service, sample_form_data):
-        """Test form creation."""
-        form = form_service.create_form(**sample_form_data)
-
-        assert form.form_id == "form-123"
-        assert form.workflow_run_id == "run-456"
-        assert form.node_id == "node-789"
-        assert form.tenant_id == "tenant-abc"
-        assert form.app_id == "app-def"
-        assert form.form_token == "token-xyz"
-        assert form.timeout == 1
-        assert form.timeout_unit == TimeoutUnit.HOUR
-        assert form.expires_at is not None
-        assert not form.is_expired
-        assert not form.is_submitted
-
-    def test_get_form_by_id(self, form_service, sample_form_data):
-        """Test getting form by ID."""
-        # Create form first
-        created_form = form_service.create_form(**sample_form_data)
-
-        # Retrieve form
-        retrieved_form = form_service.get_form_by_id("form-123")
-
-        assert retrieved_form.form_id == created_form.form_id
-        assert retrieved_form.workflow_run_id == created_form.workflow_run_id
-
-    def test_get_form_by_id_not_found(self, form_service):
-        """Test getting non-existent form by ID."""
-        with pytest.raises(FormNotFoundError) as exc_info:
-            form_service.get_form_by_id("non-existent-form")
-
-        assert exc_info.value.error_code == "form_not_found"
-
-    def test_get_form_by_token(self, form_service, sample_form_data):
-        """Test getting form by token."""
-        # Create form first
-        created_form = form_service.create_form(**sample_form_data)
-
-        # Retrieve form by token
-        retrieved_form = form_service.get_form_by_token("token-xyz")
-
-        assert retrieved_form.form_id == created_form.form_id
-        assert retrieved_form.form_token == "token-xyz"
-
-    def test_get_form_by_token_not_found(self, form_service):
-        """Test getting non-existent form by token."""
-        with pytest.raises(FormNotFoundError) as exc_info:
-            form_service.get_form_by_token("non-existent-token")
-
-        assert exc_info.value.error_code == "form_not_found"
-
-    def test_get_form_definition_by_id(self, form_service, sample_form_data):
-        """Test getting form definition by ID."""
-        # Create form first
-        form_service.create_form(**sample_form_data)
-
-        # Get form definition
-        definition = form_service.get_form_definition("form-123", is_token=False)
-
-        assert "form_content" in definition
-        assert "inputs" in definition
-        assert definition["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}"
-        assert len(definition["inputs"]) == 1
-        assert "site" not in definition  # Should not include site info for ID-based access
-
-    def test_get_form_definition_by_token(self, form_service, sample_form_data):
-        """Test getting form definition by token."""
-        # Create form first
-        form_service.create_form(**sample_form_data)
-
-        # Get form definition
-        definition = form_service.get_form_definition("token-xyz", is_token=True)
-
-        assert "form_content" in definition
-        assert "inputs" in definition
-        assert "site" in definition  # Should include site info for token-based access
-
-    def test_get_form_definition_expired_form(self, form_service, sample_form_data):
-        """Test getting definition for expired form."""
-        # Create form with past expiry
-        form_service.create_form(**sample_form_data)
-
-        # Manually expire the form by modifying expiry time
-        form = form_service.get_form_by_id("form-123")
-        form.expires_at = datetime.utcnow() - timedelta(hours=1)
-        form_service.repository.save(form)
-
-        # Should raise FormExpiredError
-        with pytest.raises(FormExpiredError) as exc_info:
-            form_service.get_form_definition("form-123", is_token=False)
-
-        assert exc_info.value.error_code == "human_input_form_expired"
-
-    def test_get_form_definition_submitted_form(self, form_service, sample_form_data):
-        """Test getting definition for already submitted form."""
-        # Create form first
-        form_service.create_form(**sample_form_data)
-
-        # Submit the form
-        submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
-        form_service.submit_form("form-123", submission_data, is_token=False)
-
-        # Should raise FormAlreadySubmittedError
-        with pytest.raises(FormAlreadySubmittedError) as exc_info:
-            form_service.get_form_definition("form-123", is_token=False)
-
-        assert exc_info.value.error_code == "human_input_form_submitted"
-
-    def test_submit_form_success(self, form_service, sample_form_data):
-        """Test successful form submission."""
-        # Create form first
-        form_service.create_form(**sample_form_data)
-
-        # Submit form
-        submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
-
-        # Should not raise any exception
-        form_service.submit_form("form-123", submission_data, is_token=False)
-
-        # Verify form is marked as submitted
-        form = form_service.get_form_by_id("form-123")
-        assert form.is_submitted
-        assert form.submitted_data == {"input": "test value"}
-        assert form.submitted_action == "submit"
-        assert form.submitted_at is not None
-
-    def test_submit_form_missing_inputs(self, form_service, sample_form_data):
-        """Test form submission with missing inputs."""
-        # Create form first
-        form_service.create_form(**sample_form_data)
-
-        # Submit form with missing required input
-        submission_data = FormSubmissionData(
-            form_id="form-123",
-            inputs={},  # Missing required "input" field
-            action="submit",
-        )
-
-        with pytest.raises(InvalidFormDataError) as exc_info:
-            form_service.submit_form("form-123", submission_data, is_token=False)
-
-        assert "Missing required inputs" in exc_info.value.message
-        assert "input" in exc_info.value.message
-
-    def test_submit_form_invalid_action(self, form_service, sample_form_data):
-        """Test form submission with invalid action."""
-        # Create form first
-        form_service.create_form(**sample_form_data)
-
-        # Submit form with invalid action
-        submission_data = FormSubmissionData(
-            form_id="form-123",
-            inputs={"input": "test value"},
-            action="invalid_action",  # Not in the allowed actions
-        )
-
-        with pytest.raises(InvalidFormDataError) as exc_info:
-            form_service.submit_form("form-123", submission_data, is_token=False)
-
-        assert "Invalid action" in exc_info.value.message
-        assert "invalid_action" in exc_info.value.message
-
-    def test_submit_form_expired(self, form_service, sample_form_data):
-        """Test submitting expired form."""
-        # Create form first
-        form_service.create_form(**sample_form_data)
-
-        # Manually expire the form
-        form = form_service.get_form_by_id("form-123")
-        form.expires_at = datetime.utcnow() - timedelta(hours=1)
-        form_service.repository.save(form)
-
-        # Try to submit expired form
-        submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
-
-        with pytest.raises(FormExpiredError) as exc_info:
-            form_service.submit_form("form-123", submission_data, is_token=False)
-
-        assert exc_info.value.error_code == "human_input_form_expired"
-
-    def test_submit_form_already_submitted(self, form_service, sample_form_data):
-        """Test submitting form that's already submitted."""
-        # Create and submit form first
-        form_service.create_form(**sample_form_data)
-
-        submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "first submission"}, action="submit")
-        form_service.submit_form("form-123", submission_data, is_token=False)
-
-        # Try to submit again
-        second_submission = FormSubmissionData(
-            form_id="form-123", inputs={"input": "second submission"}, action="submit"
-        )
-
-        with pytest.raises(FormAlreadySubmittedError) as exc_info:
-            form_service.submit_form("form-123", second_submission, is_token=False)
-
-        assert exc_info.value.error_code == "human_input_form_submitted"
-
-    def test_cleanup_expired_forms(self, form_service, sample_form_data):
-        """Test cleanup of expired forms."""
-        # Create multiple forms
-        for i in range(3):
-            data = sample_form_data.copy()
-            data["form_id"] = f"form-{i}"
-            data["form_token"] = f"token-{i}"
-            form_service.create_form(**data)
-
-        # Manually expire some forms
-        for i in range(2):  # Expire first 2 forms
-            form = form_service.get_form_by_id(f"form-{i}")
-            form.expires_at = naive_utc_now() - timedelta(hours=1)
-            form_service.repository.save(form)
-
-        # Clean up expired forms
-        cleaned_count = form_service.cleanup_expired_forms()
-
-        assert cleaned_count == 2
-
-        # Verify expired forms are gone
-        with pytest.raises(FormNotFoundError):
-            form_service.get_form_by_id("form-0")
-
-        with pytest.raises(FormNotFoundError):
-            form_service.get_form_by_id("form-1")
-
-        # Verify non-expired form still exists
-        form = form_service.get_form_by_id("form-2")
-        assert form.form_id == "form-2"
-
-
-class TestFormValidation:
-    """Test form validation logic."""
-
-    def test_validate_submission_with_extra_inputs(self):
-        """Test validation allows extra inputs that aren't defined in form."""
-        repository = InMemoryFormRepository()
-        form_service = FormService(repository)
-
-        # Create form with one input
-        form_data = {
-            "form_id": "form-123",
-            "workflow_run_id": "run-456",
-            "node_id": "node-789",
-            "tenant_id": "tenant-abc",
-            "app_id": "app-def",
-            "form_content": "Test form",
-            "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="required_input", default=None)],
-            "user_actions": [UserAction(id="submit", title="Submit")],
-            "timeout": 1,
-            "timeout_unit": TimeoutUnit.HOUR,
-        }
-
-        form_service.create_form(**form_data)
-
-        # Submit with extra input (should be allowed)
-        submission_data = FormSubmissionData(
-            form_id="form-123",
-            inputs={
-                "required_input": "value1",
-                "extra_input": "value2",  # Extra input not defined in form
-            },
-            action="submit",
-        )
-
-        # Should not raise any exception
-        form_service.submit_form("form-123", submission_data, is_token=False)
diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py
deleted file mode 100644
index 962eeb9e11..0000000000
--- a/api/tests/unit_tests/libs/_human_input/test_models.py
+++ /dev/null
@@ -1,232 +0,0 @@
-"""
-Unit tests for human input form models.
-"""
-
-from datetime import datetime, timedelta
-
-import pytest
-
-from core.workflow.nodes.human_input.entities import (
-    FormInput,
-    UserAction,
-)
-from core.workflow.nodes.human_input.enums import (
-    FormInputType,
-    TimeoutUnit,
-)
-
-from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm
-
-
-class TestHumanInputForm:
-    """Test HumanInputForm model."""
-
-    @pytest.fixture
-    def sample_form_data(self):
-        """Create sample form data."""
-        return {
-            "form_id": "form-123",
-            "workflow_run_id": "run-456",
-            "node_id": "node-789",
-            "tenant_id": "tenant-abc",
-            "app_id": "app-def",
-            "form_content": "# Test Form\n\nInput: {{#$output.input#}}",
-            "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)],
-            "user_actions": [UserAction(id="submit", title="Submit")],
-            "timeout": 2,
-            "timeout_unit": TimeoutUnit.HOUR,
-            "form_token": "token-xyz",
-        }
-
-    def test_form_creation(self, sample_form_data):
-        """Test form creation."""
-        form = HumanInputForm(**sample_form_data)
-
-        assert form.form_id == "form-123"
-        assert form.workflow_run_id == "run-456"
-        assert form.node_id == "node-789"
-        assert form.tenant_id == "tenant-abc"
-        assert form.app_id == "app-def"
-        assert form.form_token == "token-xyz"
-        assert form.timeout == 2
-        assert form.timeout_unit == TimeoutUnit.HOUR
-        assert form.created_at is not None
-        assert form.expires_at is not None
-        assert form.submitted_at is None
-        assert form.submitted_data is None
-        assert form.submitted_action is None
-
-    def test_form_expiry_calculation_hours(self, sample_form_data):
-        """Test form expiry calculation for hours."""
-        form = HumanInputForm(**sample_form_data)
-
-        # Should expire 2 hours after creation
-        expected_expiry = form.created_at + timedelta(hours=2)
-        assert abs((form.expires_at - expected_expiry).total_seconds()) < 1  # Within 1 second
-
-    def test_form_expiry_calculation_days(self, sample_form_data):
-        """Test form expiry calculation for days."""
-        sample_form_data["timeout"] = 3
-        sample_form_data["timeout_unit"] = TimeoutUnit.DAY
-
-        form = HumanInputForm(**sample_form_data)
-
-        # Should expire 3 days after creation
-        expected_expiry = form.created_at + timedelta(days=3)
-        assert abs((form.expires_at - expected_expiry).total_seconds()) < 1  # Within 1 second
-
-    def test_form_expiry_property_not_expired(self, sample_form_data):
-        """Test is_expired property for non-expired form."""
-        form = HumanInputForm(**sample_form_data)
-        assert not form.is_expired
-
-    def test_form_expiry_property_expired(self, sample_form_data):
-        """Test is_expired property for expired form."""
-        # Create form with past expiry
-        past_time = datetime.utcnow() - timedelta(hours=1)
-        sample_form_data["created_at"] = past_time
-
-        form = HumanInputForm(**sample_form_data)
-        # Manually set expiry to past time
-        form.expires_at = past_time
-
-        assert form.is_expired
-
-    def test_form_submission_property_not_submitted(self, sample_form_data):
-        """Test is_submitted property for non-submitted form."""
-        form = HumanInputForm(**sample_form_data)
-        assert not form.is_submitted
-
-    def test_form_submission_property_submitted(self, sample_form_data):
-        """Test is_submitted property for submitted form."""
-        form = HumanInputForm(**sample_form_data)
-        form.submit({"input": "test value"}, "submit")
-
-        assert form.is_submitted
-        assert form.submitted_at is not None
-        assert form.submitted_data == {"input": "test value"}
-        assert form.submitted_action == "submit"
-
-    def test_form_submit_method(self, sample_form_data):
-        """Test form submit method."""
-        form = HumanInputForm(**sample_form_data)
-
-        submission_time_before = datetime.utcnow()
-        form.submit({"input": "test value"}, "submit")
-        submission_time_after = datetime.utcnow()
-
-        assert form.is_submitted
-        assert form.submitted_data == {"input": "test value"}
-        assert form.submitted_action == "submit"
-        assert submission_time_before <= form.submitted_at <= submission_time_after
-
-    def test_form_to_response_dict_without_site_info(self, sample_form_data):
-        """Test converting form to response dict without site info."""
-        form = HumanInputForm(**sample_form_data)
-
-        response = form.to_response_dict(include_site_info=False)
-
-        assert "form_content" in response
-        assert "inputs" in response
-        assert "site" not in response
-        assert response["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}"
-        assert len(response["inputs"]) == 1
-        assert response["inputs"][0]["type"] == "text-input"
-        assert response["inputs"][0]["output_variable_name"] == "input"
-
-    def test_form_to_response_dict_with_site_info(self, sample_form_data):
-        """Test converting form to response dict with site info."""
-        form = HumanInputForm(**sample_form_data)
-
-        response = form.to_response_dict(include_site_info=True)
-
-        assert "form_content" in response
-        assert "inputs" in response
-        assert "site" in response
-        assert response["site"]["app_id"] == "app-def"
-        assert response["site"]["title"] == "Workflow Form"
-
-    def test_form_without_web_app_token(self, sample_form_data):
-        """Test form creation without web app token."""
-        sample_form_data["form_token"] = None
-
-        form = HumanInputForm(**sample_form_data)
-
-        assert form.form_token is None
-        assert form.form_id == "form-123"  # Other fields should still work
-
-    def test_form_with_explicit_timestamps(self):
-        """Test form creation with explicit timestamps."""
-        created_time = datetime(2024, 1, 15, 10, 30, 0)
-        expires_time = datetime(2024, 1, 15, 12, 30, 0)
-
-        form = HumanInputForm(
-            form_id="form-123",
-            workflow_run_id="run-456",
-            node_id="node-789",
-            tenant_id="tenant-abc",
-            app_id="app-def",
-            form_content="Test content",
-            inputs=[],
-            user_actions=[],
-            timeout=2,
-            timeout_unit=TimeoutUnit.HOUR,
-            created_at=created_time,
-            expires_at=expires_time,
-        )
-
-        assert form.created_at == created_time
-        assert form.expires_at == expires_time
-
-
-class TestFormSubmissionData:
-    """Test FormSubmissionData model."""
-
-    def test_submission_data_creation(self):
-        """Test submission data creation."""
-        submission_data = FormSubmissionData(
-            form_id="form-123", inputs={"field1": "value1", "field2": "value2"}, action="submit"
-        )
-
-        assert submission_data.form_id == "form-123"
-        assert submission_data.inputs == {"field1": "value1", "field2": "value2"}
-        assert submission_data.action == "submit"
-        assert submission_data.submitted_at is not None
-
-    def test_submission_data_from_request(self):
-        """Test creating submission data from API request."""
-        request = FormSubmissionRequest(inputs={"input": "test value"}, action="confirm")
-
-        submission_data = FormSubmissionData.from_request("form-456", request)
-
-        assert submission_data.form_id == "form-456"
-        assert submission_data.inputs == {"input": "test value"}
-        assert submission_data.action == "confirm"
-        assert submission_data.submitted_at is not None
-
-    def test_submission_data_with_empty_inputs(self):
-        """Test submission data with empty inputs."""
-        submission_data = FormSubmissionData(form_id="form-123", inputs={}, action="cancel")
-
-        assert submission_data.inputs == {}
-        assert submission_data.action == "cancel"
-
-    def test_submission_data_timestamps(self):
-        """Test submission data timestamp handling."""
-        before_time = datetime.utcnow()
-
-        submission_data = FormSubmissionData(form_id="form-123", inputs={"test": "value"}, action="submit")
-
-        after_time = datetime.utcnow()
-
-        assert before_time <= submission_data.submitted_at <= after_time
-
-    def test_submission_data_with_explicit_timestamp(self):
-        """Test submission data with explicit timestamp."""
-        specific_time = datetime(2024, 1, 15, 14, 30, 0)
-
-        submission_data = FormSubmissionData(
-            form_id="form-123", inputs={"test": "value"}, action="submit", submitted_at=specific_time
-        )
-
-        assert submission_data.submitted_at == specific_time
diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py
index 1a93dbbca1..de74eff82f 100644
--- a/api/tests/unit_tests/libs/test_helper.py
+++ b/api/tests/unit_tests/libs/test_helper.py
@@ -1,8 +1,6 @@
-from datetime import datetime
-
 import pytest
 
-from libs.helper import OptionalTimestampField, escape_like_pattern, extract_tenant_id
+from libs.helper import escape_like_pattern, extract_tenant_id
 from models.account import Account
 from models.model import EndUser
 
@@ -67,19 +65,6 @@ class TestExtractTenantId:
         extract_tenant_id(dict_user)
 
 
-class TestOptionalTimestampField:
-    def test_format_returns_none_for_none(self):
-        field = OptionalTimestampField()
-
-        assert field.format(None) is None
-
-    def test_format_returns_unix_timestamp_for_datetime(self):
-        field = OptionalTimestampField()
-        value = datetime(2024, 1, 2, 3, 4, 5)
-
-        assert field.format(value) == int(value.timestamp())
-
-
 class TestEscapeLikePattern:
     """Test cases for the escape_like_pattern utility function."""
 
diff --git a/api/tests/unit_tests/libs/test_rate_limiter.py b/api/tests/unit_tests/libs/test_rate_limiter.py
deleted file mode 100644
index 9d44b07b5e..0000000000
--- a/api/tests/unit_tests/libs/test_rate_limiter.py
+++ /dev/null
@@ -1,68 +0,0 @@
-from unittest.mock import MagicMock
-
-from libs import helper as helper_module
-
-
-class _FakeRedis:
-    def __init__(self) -> None:
-        self._zsets: dict[str, dict[str, float]] = {}
-        self._expiry: dict[str, int] = {}
-
-    def zadd(self, key: str, mapping: dict[str, float]) -> int:
-        zset = self._zsets.setdefault(key, {})
-        for member, score in mapping.items():
-            zset[str(member)] = float(score)
-        return len(mapping)
-
-    def zremrangebyscore(self, key: str, min_score: str | float, max_score: str | float) -> int:
-        zset = self._zsets.get(key, {})
-        min_value = float("-inf") if min_score == "-inf" else float(min_score)
-        max_value = float("inf") if max_score == "+inf" else float(max_score)
-        to_delete = [member for member, score in zset.items() if min_value <= score <= max_value]
-        for member in to_delete:
-            del zset[member]
-        return len(to_delete)
-
-    def zcard(self, key: str) -> int:
-        return len(self._zsets.get(key, {}))
-
-    def expire(self, key: str, ttl: int) -> bool:
-        self._expiry[key] = ttl
-        return True
-
-
-def test_rate_limiter_counts_attempts_within_same_second(monkeypatch):
-    fake_redis = _FakeRedis()
-    monkeypatch.setattr(helper_module.time, "time", lambda: 1000)
-
-    limiter = helper_module.RateLimiter(
-        prefix="test_rate_limit",
-        max_attempts=2,
-        time_window=60,
-        redis_client=fake_redis,
-    )
-
-    limiter.increment_rate_limit("203.0.113.10")
-    limiter.increment_rate_limit("203.0.113.10")
-
-    assert limiter.is_rate_limited("203.0.113.10") is True
-
-
-def test_rate_limiter_uses_injected_redis(monkeypatch):
-    redis_client = MagicMock()
-    redis_client.zcard.return_value = 1
-    monkeypatch.setattr(helper_module.time, "time", lambda: 1000)
-
-    limiter = helper_module.RateLimiter(
-        prefix="test_rate_limit",
-        max_attempts=1,
-        time_window=60,
-        redis_client=redis_client,
-    )
-
-    limiter.increment_rate_limit("203.0.113.10")
-    limiter.is_rate_limited("203.0.113.10")
-
-    assert redis_client.zadd.called is True
-    assert redis_client.zremrangebyscore.called is True
-    assert redis_client.zcard.called is True
diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py
index c6dfd41803..8be2eea121 100644
--- a/api/tests/unit_tests/models/test_app_models.py
+++ b/api/tests/unit_tests/models/test_app_models.py
@@ -1296,7 +1296,6 @@ class TestConversationStatusCount:
         assert result["success"] == 1  # One SUCCEEDED
         assert result["failed"] == 1  # One FAILED
        assert result["partial_success"] == 1  # One PARTIAL_SUCCEEDED
-        assert result["paused"] == 0
 
     def test_status_count_app_id_filtering(self):
         """Test that status_count filters workflow runs by app_id for security."""
@@ -1351,7 +1350,6 @@ class TestConversationStatusCount:
         assert result["success"] == 0
         assert result["failed"] == 0
         assert result["partial_success"] == 0
-        assert result["paused"] == 0
 
     def test_status_count_handles_invalid_workflow_status(self):
         """Test that status_count gracefully handles invalid workflow status values."""
@@ -1406,57 +1404,3 @@ class TestConversationStatusCount:
         assert result["success"] == 0
         assert result["failed"] == 0
         assert result["partial_success"] == 0
-        assert result["paused"] == 0
-
-    def test_status_count_paused(self):
-        """Test status_count includes paused workflow runs."""
-        # Arrange
-        from core.workflow.enums import WorkflowExecutionStatus
-
-        app_id = str(uuid4())
-        conversation_id = str(uuid4())
-        workflow_run_id = str(uuid4())
-
-        conversation = Conversation(
-            app_id=app_id,
-            mode=AppMode.CHAT,
-            name="Test Conversation",
-            status="normal",
-            from_source="api",
-        )
-        conversation.id = conversation_id
-
-        mock_messages = [
-            MagicMock(
-                conversation_id=conversation_id,
-                workflow_run_id=workflow_run_id,
-            ),
-        ]
-
-        mock_workflow_runs = [
-            MagicMock(
-                id=workflow_run_id,
-                status=WorkflowExecutionStatus.PAUSED.value,
-                app_id=app_id,
-            ),
-        ]
-
-        with patch("models.model.db.session.scalars") as mock_scalars:
-
-            def mock_scalars_side_effect(query):
-                mock_result = MagicMock()
-                if "messages" in str(query):
-                    mock_result.all.return_value = mock_messages
-                elif "workflow_runs" in str(query):
-                    mock_result.all.return_value = mock_workflow_runs
-                else:
-                    mock_result.all.return_value = []
-                return mock_result
-
-            mock_scalars.side_effect = mock_scalars_side_effect
-
-            # Act
-            result = conversation.status_count
-
-            # Assert
-            assert result["paused"] == 1
diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py
deleted file mode 100644
index ceb1406a4b..0000000000
--- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""Unit tests for DifyAPISQLAlchemyWorkflowNodeExecutionRepository implementation."""
-
-from unittest.mock import Mock
-
-from sqlalchemy.orm import Session, sessionmaker
-
-from repositories.sqlalchemy_api_workflow_node_execution_repository import (
-    DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
-)
-
-
-class TestDifyAPISQLAlchemyWorkflowNodeExecutionRepository:
-    def test_get_executions_by_workflow_run_keeps_paused_records(self):
-        mock_session = Mock(spec=Session)
-        execute_result = Mock()
-        execute_result.scalars.return_value.all.return_value = []
-        mock_session.execute.return_value = execute_result
-
-        session_maker = Mock(spec=sessionmaker)
-        context_manager = Mock()
-        context_manager.__enter__ = Mock(return_value=mock_session)
-        context_manager.__exit__ = Mock(return_value=None)
-        session_maker.return_value = context_manager
-
-        repository = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker)
-
-        repository.get_executions_by_workflow_run(
-            tenant_id="tenant-123",
-            app_id="app-123",
-            workflow_run_id="workflow-run-123",
-        )
-
-        stmt = mock_session.execute.call_args[0][0]
-        where_clauses = list(getattr(stmt, "_where_criteria", []) or [])
-        where_strs = [str(clause).lower() for clause in where_clauses]
-
-        assert any("tenant_id" in clause for clause in where_strs)
-        assert any("app_id" in clause for clause in where_strs)
-        assert any("workflow_run_id" in clause for clause in where_strs)
-        assert not any("paused" in clause for clause in where_strs)
diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
index 4caaa056ff..d443c4c9a5 100644
--- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
+++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py
@@ -1,6 +1,5 @@
 """Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""
 
-import secrets
 from datetime import UTC, datetime
 from unittest.mock import Mock, patch
 
@@ -8,17 +7,12 @@ import pytest
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Session, sessionmaker
 
-from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType
 from core.workflow.enums import WorkflowExecutionStatus
-from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction
-from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus
-from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
 from models.workflow import WorkflowPause as WorkflowPauseModel
-from models.workflow import WorkflowPauseReason, WorkflowRun
+from models.workflow import WorkflowRun
 from repositories.entities.workflow_pause import WorkflowPauseEntity
 from repositories.sqlalchemy_api_workflow_run_repository import (
     DifyAPISQLAlchemyWorkflowRunRepository,
-    _build_human_input_required_reason,
     _PrivateWorkflowPauseEntity,
     _WorkflowRunError,
 )
@@ -211,11 +205,11 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
     ):
         """Test workflow pause
creation when workflow not in RUNNING status.""" # Arrange - sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED + sample_workflow_run.status = WorkflowExecutionStatus.PAUSED mock_session.get.return_value = sample_workflow_run # Act & Assert - with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"): + with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"): repository.create_workflow_pause( workflow_run_id="workflow-run-123", state_owner_user_id="user-123", @@ -301,7 +295,6 @@ class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): sample_workflow_pause.resumed_at = None mock_session.scalar.return_value = sample_workflow_run - mock_session.scalars.return_value.all.return_value = [] with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now: mock_now.return_value = datetime.now(UTC) @@ -462,53 +455,3 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository) assert result1 == expected_state assert result2 == expected_state mock_storage.load.assert_called_once() # Only called once due to caching - - -class TestBuildHumanInputRequiredReason: - def test_prefers_backstage_token_when_available(self): - expiration_time = datetime.now(UTC) - form_definition = FormDefinition( - form_content="content", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - rendered_content="rendered", - expiration_time=expiration_time, - default_values={"name": "Alice"}, - node_title="Ask Name", - display_in_ui=True, - ) - form_model = HumanInputForm( - id="form-1", - tenant_id="tenant-1", - app_id="app-1", - workflow_run_id="run-1", - node_id="node-1", - form_definition=form_definition.model_dump_json(), - rendered_content="rendered", - status=HumanInputFormStatus.WAITING, - expiration_time=expiration_time, - ) - reason_model = WorkflowPauseReason( - pause_id="pause-1", - type_=PauseReasonType.HUMAN_INPUT_REQUIRED, - form_id="form-1", - node_id="node-1", - message="", - ) - access_token = secrets.token_urlsafe(8) - backstage_recipient = HumanInputFormRecipient( - form_id="form-1", - delivery_id="delivery-1", - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload().model_dump_json(), - access_token=access_token, - ) - - reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient]) - - assert isinstance(reason, HumanInputRequired) - assert reason.form_token == access_token - assert reason.node_title == "Ask Name" - assert reason.form_content == "content" - assert reason.inputs[0].output_variable_name == "name" - assert reason.actions[0].id == "approve" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py deleted file mode 100644 index f5428b46ff..0000000000 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ /dev/null @@ -1,180 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from dataclasses import dataclass -from datetime import UTC, datetime, timedelta - -from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain -from core.entities.execution_extra_content import HumanInputFormSubmissionData -from 
core.workflow.nodes.human_input.entities import ( - FormDefinition, - UserAction, -) -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from models.execution_extra_content import HumanInputContent as HumanInputContentModel -from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType -from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository - - -class _FakeScalarResult: - def __init__(self, values: Sequence[HumanInputContentModel]): - self._values = list(values) - - def all(self) -> list[HumanInputContentModel]: - return list(self._values) - - -class _FakeSession: - def __init__(self, values: Sequence[Sequence[object]]): - self._values = list(values) - - def scalars(self, _stmt): - if not self._values: - return _FakeScalarResult([]) - return _FakeScalarResult(self._values.pop(0)) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - -@dataclass -class _FakeSessionMaker: - session: _FakeSession - - def __call__(self) -> _FakeSession: - return self.session - - -def _build_form(action_id: str, action_title: str, rendered_content: str) -> HumanInputForm: - expiration_time = datetime.now(UTC) + timedelta(days=1) - definition = FormDefinition( - form_content="content", - inputs=[], - user_actions=[UserAction(id=action_id, title=action_title)], - rendered_content="rendered", - expiration_time=expiration_time, - node_title="Approval", - display_in_ui=True, - ) - form = HumanInputForm( - id=f"form-{action_id}", - tenant_id="tenant-id", - app_id="app-id", - workflow_run_id="workflow-run", - node_id="node-id", - form_definition=definition.model_dump_json(), - rendered_content=rendered_content, - status=HumanInputFormStatus.SUBMITTED, - expiration_time=expiration_time, - ) - form.selected_action_id = action_id - return form - - -def _build_content(message_id: str, action_id: str, action_title: str) -> HumanInputContentModel: - form = _build_form( - action_id=action_id, - action_title=action_title, - rendered_content=f"Rendered {action_title}", - ) - content = HumanInputContentModel( - id=f"content-{message_id}", - form_id=form.id, - message_id=message_id, - workflow_run_id=form.workflow_run_id, - ) - content.form = form - return content - - -def test_get_by_message_ids_groups_contents_by_message() -> None: - message_ids = ["msg-1", "msg-2"] - contents = [_build_content("msg-1", "approve", "Approve")] - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=_FakeSessionMaker(session=_FakeSession(values=[contents, []])) - ) - - result = repository.get_by_message_ids(message_ids) - - assert len(result) == 2 - assert [content.model_dump(mode="json", exclude_none=True) for content in result[0]] == [ - HumanInputContentDomain( - workflow_run_id="workflow-run", - submitted=True, - form_submission_data=HumanInputFormSubmissionData( - node_id="node-id", - node_title="Approval", - rendered_content="Rendered Approve", - action_id="approve", - action_text="Approve", - ), - ).model_dump(mode="json", exclude_none=True) - ] - assert result[1] == [] - - -def test_get_by_message_ids_returns_unsubmitted_form_definition() -> None: - expiration_time = datetime.now(UTC) + timedelta(days=1) - definition = FormDefinition( - form_content="content", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - rendered_content="rendered", - expiration_time=expiration_time, - default_values={"name": "John"}, - 
node_title="Approval", - display_in_ui=True, - ) - form = HumanInputForm( - id="form-1", - tenant_id="tenant-id", - app_id="app-id", - workflow_run_id="workflow-run", - node_id="node-id", - form_definition=definition.model_dump_json(), - rendered_content="Rendered block", - status=HumanInputFormStatus.WAITING, - expiration_time=expiration_time, - ) - content = HumanInputContentModel( - id="content-msg-1", - form_id=form.id, - message_id="msg-1", - workflow_run_id=form.workflow_run_id, - ) - content.form = form - - recipient = HumanInputFormRecipient( - form_id=form.id, - delivery_id="delivery-1", - recipient_type=RecipientType.CONSOLE, - recipient_payload=ConsoleRecipientPayload(account_id=None).model_dump_json(), - access_token="token-1", - ) - - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=_FakeSessionMaker(session=_FakeSession(values=[[content], [recipient]])) - ) - - result = repository.get_by_message_ids(["msg-1"]) - - assert len(result) == 1 - assert len(result[0]) == 1 - domain_content = result[0][0] - assert domain_content.submitted is False - assert domain_content.workflow_run_id == "workflow-run" - assert domain_content.form_definition is not None - assert domain_content.form_definition.expiration_time == int(form.expiration_time.timestamp()) - assert domain_content.form_definition is not None - form_definition = domain_content.form_definition - assert form_definition.form_id == "form-1" - assert form_definition.node_id == "node-id" - assert form_definition.node_title == "Approval" - assert form_definition.form_content == "Rendered block" - assert form_definition.display_in_ui is True - assert form_definition.form_token == "token-1" - assert form_definition.resolved_default_values == {"name": "John"} - assert form_definition.expiration_time == int(form.expiration_time.timestamp()) diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py deleted file mode 100644 index 71134464e6..0000000000 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ /dev/null @@ -1,65 +0,0 @@ -from unittest.mock import MagicMock - -import services.app_generate_service as app_generate_service_module -from models.model import AppMode -from services.app_generate_service import AppGenerateService - - -class _DummyRateLimit: - def __init__(self, client_id: str, max_active_requests: int) -> None: - self.client_id = client_id - self.max_active_requests = max_active_requests - - @staticmethod - def gen_request_key() -> str: - return "dummy-request-id" - - def enter(self, request_id: str | None = None) -> str: - return request_id or "dummy-request-id" - - def exit(self, request_id: str) -> None: - return None - - def generate(self, generator, request_id: str): - return generator - - -def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch): - monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False) - mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) - - workflow = MagicMock() - workflow.id = "workflow-id" - workflow.created_by = "owner-id" - - mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) - - generator_spy = mocker.patch( - "services.app_generate_service.WorkflowAppGenerator.generate", - return_value={"result": "ok"}, - ) - - app_model = MagicMock() - app_model.mode = AppMode.WORKFLOW - app_model.id = "app-id" - app_model.tenant_id = "tenant-id" - app_model.max_active_requests = 0 - 
app_model.is_agent = False - - user = MagicMock() - user.id = "user-id" - - result = AppGenerateService.generate( - app_model=app_model, - user=user, - args={"inputs": {"k": "v"}}, - invoke_from=MagicMock(), - streaming=False, - ) - - assert result == {"result": "ok"} - - call_kwargs = generator_spy.call_args.kwargs - pause_state_config = call_kwargs.get("pause_state_config") - assert pause_state_config is not None - assert pause_state_config.state_owner_user_id == "owner-id" diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index eca1d44d23..81135dbbdf 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -508,12 +508,9 @@ class TestConversationServiceMessageCreation: within conversations. """ - @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_by_first_id_without_first_id( - self, mock_get_conversation, mock_db_session, mock_create_extra_repo - ): + def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session): """ Test message pagination without specifying first_id. @@ -543,9 +540,6 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages - mock_repository = MagicMock() - mock_repository.get_by_message_ids.return_value = [[] for _ in messages] - mock_create_extra_repo.return_value = mock_repository # Act - Call the pagination method without first_id result = MessageService.pagination_by_first_id( @@ -562,10 +556,9 @@ class TestConversationServiceMessageCreation: # Verify conversation was looked up with correct parameters mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id) - @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): + def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session): """ Test message pagination with first_id specified. 
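(Aside on the pattern in the surrounding hunks: these pagination tests stub SQLAlchemy's fluent query API by making each builder call return the mock itself, so order_by/limit can chain in any order before the terminal .all() or .first() yields the fixture data. A minimal sketch of that pattern, reusing the mock_query name from the hunks above — an illustration only, not the service's code:

    from unittest.mock import MagicMock

    mock_query = MagicMock()
    mock_query.order_by.return_value = mock_query  # ORDER BY returns self for chaining
    mock_query.limit.return_value = mock_query     # LIMIT returns self for chaining
    mock_query.first.return_value = None           # terminal call: single row lookup
    mock_query.all.return_value = []               # terminal call: the fixture messages

Because every chaining call returns the same object, the test never has to know the exact order in which the service composes the query.)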
@@ -597,9 +590,6 @@ class TestConversationServiceMessageCreation: mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.first.return_value = first_message # First message returned mock_query.all.return_value = messages # Remaining messages returned - mock_repository = MagicMock() - mock_repository.get_by_message_ids.return_value = [[] for _ in messages] - mock_create_extra_repo.return_value = mock_repository # Act - Call the pagination method with first_id result = MessageService.pagination_by_first_id( @@ -694,10 +684,9 @@ class TestConversationServiceMessageCreation: assert result.data == [] assert result.has_more is False - @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): + def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session): """ Test that has_more flag is correctly set when there are more messages. @@ -727,9 +716,6 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages - mock_repository = MagicMock() - mock_repository.get_by_message_ids.return_value = [[] for _ in messages] - mock_create_extra_repo.return_value = mock_repository # Act result = MessageService.pagination_by_first_id( @@ -744,10 +730,9 @@ class TestConversationServiceMessageCreation: assert len(result.data) == limit # Extra message should be removed assert result.has_more is True # Flag should be set - @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): + def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session): """ Test message pagination with ascending order. 
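(Aside: the has_more assertions in the hunk above — "Extra message should be removed", "Flag should be set" — depend on the common fetch-one-extra idiom: the query requests limit + 1 rows, a surplus row signals that another page exists, and the surplus is dropped before returning. A rough sketch of the idiom as a hypothetical helper, not the service's actual implementation:

    def split_page(rows: list, limit: int) -> tuple[list, bool]:
        # Fetching limit + 1 rows lets the extra row act as the has_more signal.
        has_more = len(rows) > limit
        return rows[:limit], has_more

)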
@@ -776,9 +761,6 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages - mock_repository = MagicMock() - mock_repository.get_by_message_ids.return_value = [[] for _ in messages] - mock_create_extra_repo.return_value = mock_repository # Act result = MessageService.pagination_by_first_id( diff --git a/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py b/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py deleted file mode 100644 index ab141a7b2d..0000000000 --- a/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py +++ /dev/null @@ -1,104 +0,0 @@ -from dataclasses import dataclass - -import pytest - -from enums.cloud_plan import CloudPlan -from services import feature_service as feature_service_module -from services.feature_service import FeatureModel, FeatureService - - -@dataclass(frozen=True) -class HumanInputEmailDeliveryCase: - name: str - enterprise_enabled: bool - billing_enabled: bool - tenant_id: str | None - billing_feature_enabled: bool - plan: str - expected: bool - - -CASES = [ - HumanInputEmailDeliveryCase( - name="enterprise_enabled", - enterprise_enabled=True, - billing_enabled=True, - tenant_id=None, - billing_feature_enabled=False, - plan=CloudPlan.SANDBOX, - expected=True, - ), - HumanInputEmailDeliveryCase( - name="billing_disabled", - enterprise_enabled=False, - billing_enabled=False, - tenant_id=None, - billing_feature_enabled=False, - plan=CloudPlan.SANDBOX, - expected=True, - ), - HumanInputEmailDeliveryCase( - name="billing_enabled_requires_tenant", - enterprise_enabled=False, - billing_enabled=True, - tenant_id=None, - billing_feature_enabled=True, - plan=CloudPlan.PROFESSIONAL, - expected=False, - ), - HumanInputEmailDeliveryCase( - name="billing_feature_off", - enterprise_enabled=False, - billing_enabled=True, - tenant_id="tenant-1", - billing_feature_enabled=False, - plan=CloudPlan.PROFESSIONAL, - expected=False, - ), - HumanInputEmailDeliveryCase( - name="professional_plan", - enterprise_enabled=False, - billing_enabled=True, - tenant_id="tenant-1", - billing_feature_enabled=True, - plan=CloudPlan.PROFESSIONAL, - expected=True, - ), - HumanInputEmailDeliveryCase( - name="team_plan", - enterprise_enabled=False, - billing_enabled=True, - tenant_id="tenant-1", - billing_feature_enabled=True, - plan=CloudPlan.TEAM, - expected=True, - ), - HumanInputEmailDeliveryCase( - name="sandbox_plan", - enterprise_enabled=False, - billing_enabled=True, - tenant_id="tenant-1", - billing_feature_enabled=True, - plan=CloudPlan.SANDBOX, - expected=False, - ), -] - - -@pytest.mark.parametrize("case", CASES, ids=lambda case: case.name) -def test_resolve_human_input_email_delivery_enabled_matrix( - monkeypatch: pytest.MonkeyPatch, - case: HumanInputEmailDeliveryCase, -): - monkeypatch.setattr(feature_service_module.dify_config, "ENTERPRISE_ENABLED", case.enterprise_enabled) - monkeypatch.setattr(feature_service_module.dify_config, "BILLING_ENABLED", case.billing_enabled) - features = FeatureModel() - features.billing.enabled = case.billing_feature_enabled - features.billing.subscription.plan = case.plan - - result = FeatureService._resolve_human_input_email_delivery_enabled( - features=features, - tenant_id=case.tenant_id, - ) - - assert result is case.expected diff --git 
a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py deleted file mode 100644 index e0d6ad1b39..0000000000 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ /dev/null @@ -1,97 +0,0 @@ -from types import SimpleNamespace - -import pytest - -from core.workflow.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, -) -from core.workflow.runtime import VariablePool -from services import human_input_delivery_test_service as service_module -from services.human_input_delivery_test_service import ( - DeliveryTestContext, - DeliveryTestError, - EmailDeliveryTestHandler, -) - - -def _make_email_method() -> EmailDeliveryMethod: - return EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ExternalRecipient(email="tester@example.com")], - ), - subject="Test subject", - body="Test body", - ) - ) - - -def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr( - service_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), - ) - - handler = EmailDeliveryTestHandler(session_factory=object()) - context = DeliveryTestContext( - tenant_id="tenant-1", - app_id="app-1", - node_id="node-1", - node_title="Human Input", - rendered_content="content", - ) - method = _make_email_method() - - with pytest.raises(DeliveryTestError, match="Email delivery is not available"): - handler.send_test(context=context, method=method) - - -def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch): - class DummyMail: - def __init__(self): - self.sent: list[dict[str, str]] = [] - - def is_inited(self) -> bool: - return True - - def send(self, *, to: str, subject: str, html: str): - self.sent.append({"to": to, "subject": subject, "html": html}) - - mail = DummyMail() - monkeypatch.setattr(service_module, "mail", mail) - monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template) - monkeypatch.setattr( - service_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), - ) - - handler = EmailDeliveryTestHandler(session_factory=object()) - handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"] # type: ignore[assignment] - - method = EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]), - subject="Subject", - body="Value {{#node1.value#}}", - ) - ) - variable_pool = VariablePool() - variable_pool.add(["node1", "value"], "OK") - context = DeliveryTestContext( - tenant_id="tenant-1", - app_id="app-1", - node_id="node-1", - node_title="Human Input", - rendered_content="content", - variable_pool=variable_pool, - ) - - handler.send_test(context=context, method=method) - - assert mail.sent[0]["html"] == "Value OK" diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py deleted file mode 100644 index d2cf74daf3..0000000000 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ /dev/null @@ -1,290 +0,0 @@ -import dataclasses -from datetime import datetime, timedelta -from unittest.mock import MagicMock - -import pytest 
- -import services.human_input_service as human_input_service_module -from core.repositories.human_input_repository import ( - HumanInputFormRecord, - HumanInputFormSubmissionRepository, -) -from core.workflow.nodes.human_input.entities import ( - FormDefinition, - FormInput, - UserAction, -) -from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus -from models.human_input import RecipientType -from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError -from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE - - -@pytest.fixture -def mock_session_factory(): - session = MagicMock() - session_cm = MagicMock() - session_cm.__enter__.return_value = session - session_cm.__exit__.return_value = None - - factory = MagicMock() - factory.return_value = session_cm - return factory, session - - -@pytest.fixture -def sample_form_record(): - return HumanInputFormRecord( - form_id="form-id", - workflow_run_id="workflow-run-id", - node_id="node-id", - tenant_id="tenant-id", - app_id="app-id", - form_kind=HumanInputFormKind.RUNTIME, - definition=FormDefinition( - form_content="hello", - inputs=[], - user_actions=[UserAction(id="submit", title="Submit")], - rendered_content="
<p>hello</p>", - expiration_time=datetime.utcnow() + timedelta(hours=1), - ), - rendered_content="<p>hello</p>
", - created_at=datetime.utcnow(), - expiration_time=datetime.utcnow() + timedelta(hours=1), - status=HumanInputFormStatus.WAITING, - selected_action_id=None, - submitted_data=None, - submitted_at=None, - submission_user_id=None, - submission_end_user_id=None, - completed_by_recipient_id=None, - recipient_id="recipient-id", - recipient_type=RecipientType.STANDALONE_WEB_APP, - access_token="token", - ) - - -def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory): - session_factory, session = mock_session_factory - service = HumanInputService(session_factory) - - workflow_run = MagicMock() - workflow_run.app_id = "app-id" - - workflow_run_repo = MagicMock() - workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run - mocker.patch( - "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", - return_value=workflow_run_repo, - ) - - app = MagicMock() - app.mode = "workflow" - session.execute.return_value.scalar_one_or_none.return_value = app - - resume_task = mocker.patch("services.human_input_service.resume_app_execution") - - service.enqueue_resume("workflow-run-id") - - resume_task.apply_async.assert_called_once() - call_kwargs = resume_task.apply_async.call_args.kwargs - assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE - assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id" - - -def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_record, mock_session_factory): - session_factory, _ = mock_session_factory - service = HumanInputService(session_factory) - expired_record = dataclasses.replace( - sample_form_record, - created_at=datetime.utcnow() - timedelta(hours=2), - expiration_time=datetime.utcnow() + timedelta(hours=2), - ) - monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600) - - with pytest.raises(FormExpiredError): - service.ensure_form_active(Form(expired_record)) - - -def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory): - session_factory, session = mock_session_factory - service = HumanInputService(session_factory) - - workflow_run = MagicMock() - workflow_run.app_id = "app-id" - - workflow_run_repo = MagicMock() - workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run - mocker.patch( - "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", - return_value=workflow_run_repo, - ) - - app = MagicMock() - app.mode = "advanced-chat" - session.execute.return_value.scalar_one_or_none.return_value = app - - resume_task = mocker.patch("services.human_input_service.resume_app_execution") - - service.enqueue_resume("workflow-run-id") - - resume_task.apply_async.assert_called_once() - call_kwargs = resume_task.apply_async.call_args.kwargs - assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE - assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id" - - -def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory): - session_factory, session = mock_session_factory - service = HumanInputService(session_factory) - - workflow_run = MagicMock() - workflow_run.app_id = "app-id" - - workflow_run_repo = MagicMock() - workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run - mocker.patch( - "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", - return_value=workflow_run_repo, - ) 
- - app = MagicMock() - app.mode = "completion" - session.execute.return_value.scalar_one_or_none.return_value = app - - resume_task = mocker.patch("services.human_input_service.resume_app_execution") - - service.enqueue_resume("workflow-run-id") - - resume_task.apply_async.assert_not_called() - - -def test_get_form_definition_by_token_for_console_uses_repository(sample_form_record, mock_session_factory): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - console_record = dataclasses.replace(sample_form_record, recipient_type=RecipientType.CONSOLE) - repo.get_by_token.return_value = console_record - - service = HumanInputService(session_factory, form_repository=repo) - form = service.get_form_definition_by_token_for_console("token") - - repo.get_by_token.assert_called_once_with("token") - assert form is not None - assert form.get_definition() == console_record.definition - - -def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - repo.get_by_token.return_value = sample_form_record - repo.mark_submitted.return_value = sample_form_record - service = HumanInputService(session_factory, form_repository=repo) - enqueue_spy = mocker.patch.object(service, "enqueue_resume") - - service.submit_form_by_token( - recipient_type=RecipientType.STANDALONE_WEB_APP, - form_token="token", - selected_action_id="submit", - form_data={"field": "value"}, - submission_end_user_id="end-user-id", - ) - - repo.get_by_token.assert_called_once_with("token") - repo.mark_submitted.assert_called_once() - call_kwargs = repo.mark_submitted.call_args.kwargs - assert call_kwargs["form_id"] == sample_form_record.form_id - assert call_kwargs["recipient_id"] == sample_form_record.recipient_id - assert call_kwargs["selected_action_id"] == "submit" - assert call_kwargs["form_data"] == {"field": "value"} - assert call_kwargs["submission_end_user_id"] == "end-user-id" - enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) - - -def test_submit_form_by_token_skips_enqueue_for_delivery_test(sample_form_record, mock_session_factory, mocker): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - test_record = dataclasses.replace( - sample_form_record, - form_kind=HumanInputFormKind.DELIVERY_TEST, - workflow_run_id=None, - ) - repo.get_by_token.return_value = test_record - repo.mark_submitted.return_value = test_record - service = HumanInputService(session_factory, form_repository=repo) - enqueue_spy = mocker.patch.object(service, "enqueue_resume") - - service.submit_form_by_token( - recipient_type=RecipientType.STANDALONE_WEB_APP, - form_token="token", - selected_action_id="submit", - form_data={"field": "value"}, - ) - - enqueue_spy.assert_not_called() - - -def test_submit_form_by_token_passes_submission_user_id(sample_form_record, mock_session_factory, mocker): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - repo.get_by_token.return_value = sample_form_record - repo.mark_submitted.return_value = sample_form_record - service = HumanInputService(session_factory, form_repository=repo) - enqueue_spy = mocker.patch.object(service, "enqueue_resume") - - service.submit_form_by_token( - recipient_type=RecipientType.STANDALONE_WEB_APP, - form_token="token", - selected_action_id="submit", - 
form_data={"field": "value"}, - submission_user_id="account-id", - ) - - call_kwargs = repo.mark_submitted.call_args.kwargs - assert call_kwargs["submission_user_id"] == "account-id" - assert call_kwargs["submission_end_user_id"] is None - enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) - - -def test_submit_form_by_token_invalid_action(sample_form_record, mock_session_factory): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - repo.get_by_token.return_value = dataclasses.replace(sample_form_record) - service = HumanInputService(session_factory, form_repository=repo) - - with pytest.raises(InvalidFormDataError) as exc_info: - service.submit_form_by_token( - recipient_type=RecipientType.STANDALONE_WEB_APP, - form_token="token", - selected_action_id="invalid", - form_data={}, - ) - - assert "Invalid action" in str(exc_info.value) - repo.mark_submitted.assert_not_called() - - -def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_factory): - session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormSubmissionRepository) - - definition_with_input = FormDefinition( - form_content="hello", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content")], - user_actions=sample_form_record.definition.user_actions, - rendered_content="
<p>hello</p>
", - expiration_time=sample_form_record.expiration_time, - ) - form_with_input = dataclasses.replace(sample_form_record, definition=definition_with_input) - repo.get_by_token.return_value = form_with_input - service = HumanInputService(session_factory, form_repository=repo) - - with pytest.raises(InvalidFormDataError) as exc_info: - service.submit_form_by_token( - recipient_type=RecipientType.STANDALONE_WEB_APP, - form_token="token", - selected_action_id="submit", - form_data={}, - ) - - assert "Missing required inputs" in str(exc_info.value) - repo.mark_submitted.assert_not_called() diff --git a/api/tests/unit_tests/services/test_message_service_extra_contents.py b/api/tests/unit_tests/services/test_message_service_extra_contents.py deleted file mode 100644 index 3c8e301caa..0000000000 --- a/api/tests/unit_tests/services/test_message_service_extra_contents.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -import pytest - -from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData -from services import message_service - - -class _FakeMessage: - def __init__(self, message_id: str): - self.id = message_id - self.extra_contents = None - - def set_extra_contents(self, contents): - self.extra_contents = contents - - -def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None: - messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")] - repo = type( - "Repo", - (), - { - "get_by_message_ids": lambda _self, message_ids: [ - [ - HumanInputContent( - workflow_run_id="workflow-run-1", - submitted=True, - form_submission_data=HumanInputFormSubmissionData( - node_id="node-1", - node_title="Approval", - rendered_content="Rendered", - action_id="approve", - action_text="Approve", - ), - ) - ], - [], - ] - }, - )() - - monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo) - - message_service.attach_message_extra_contents(messages) - - assert messages[0].extra_contents == [ - { - "type": "human_input", - "workflow_run_id": "workflow-run-1", - "submitted": True, - "form_submission_data": { - "node_id": "node-1", - "node_title": "Approval", - "rendered_content": "Rendered", - "action_id": "approve", - "action_text": "Approve", - }, - } - ] - assert messages[1].extra_contents == [] diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index ded141f01a..f45a72927e 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -35,6 +35,7 @@ class TestDataFactory: app_id: str = "app-789", workflow_id: str = "workflow-101", status: str | WorkflowExecutionStatus = "paused", + pause_id: str | None = None, **kwargs, ) -> MagicMock: """Create a mock WorkflowRun object.""" @@ -44,6 +45,7 @@ class TestDataFactory: mock_run.app_id = app_id mock_run.workflow_id = workflow_id mock_run.status = status + mock_run.pause_id = pause_id for key, value in kwargs.items(): setattr(mock_run, key, value) diff --git a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py deleted file mode 100644 index d6c92f1013..0000000000 --- a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py +++ /dev/null @@ -1,158 +0,0 @@ -import json -from types import SimpleNamespace -from unittest.mock import 
MagicMock - -import pytest - -from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from models.model import App -from models.tools import WorkflowToolProvider -from services.tools import workflow_tools_manage_service - - -class DummyWorkflow: - def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None: - self._graph_dict = graph_dict - self.version = version - - @property - def graph_dict(self) -> dict: - return self._graph_dict - - -class FakeQuery: - def __init__(self, result): - self._result = result - - def where(self, *args, **kwargs): - return self - - def first(self): - return self._result - - -class DummySession: - def __init__(self) -> None: - self.added: list[object] = [] - - def __enter__(self) -> "DummySession": - return self - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - def add(self, obj) -> None: - self.added.append(obj) - - def begin(self): - return DummyBegin(self) - - -class DummyBegin: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionContext: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionFactory: - def __init__(self, session: DummySession) -> None: - self._session = session - - def create_session(self) -> DummySessionContext: - return DummySessionContext(self._session) - - -def _build_fake_session(app) -> SimpleNamespace: - def query(model): - if model is WorkflowToolProvider: - return FakeQuery(None) - if model is App: - return FakeQuery(app) - return FakeQuery(None) - - return SimpleNamespace(query=query) - - -def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]}) - app = SimpleNamespace(workflow=workflow) - - fake_session = _build_fake_session(app) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) - - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - mock_invalidate = MagicMock() - - parameters = [{"name": "input", "description": "input", "form": "form"}] - - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon={"type": "emoji", "emoji": "tool"}, - description="desc", - parameters=parameters, - ) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - mock_from_db.assert_not_called() - mock_invalidate.assert_not_called() - - -def test_create_workflow_tool_success(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]}) - app = SimpleNamespace(workflow=workflow) - - fake_db = MagicMock() - fake_session = _build_fake_session(app) - fake_db.session = fake_session - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - - mock_from_db = MagicMock() - 
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - - parameters = [{"name": "input", "description": "input", "form": "form"}] - icon = {"type": "emoji", "emoji": "tool"} - - result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon=icon, - description="desc", - parameters=parameters, - ) - - assert result == {"result": "success"} - assert len(dummy_session.added) == 1 - created_provider = dummy_session.added[0] - assert created_provider.name == "tool_name" - assert created_provider.label == "Tool" - assert created_provider.icon == json.dumps(icon) - assert created_provider.version == workflow.version - mock_from_db.assert_called_once() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py deleted file mode 100644 index 844dab8976..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ /dev/null @@ -1,226 +0,0 @@ -from __future__ import annotations - -import json -import queue -from collections.abc import Sequence -from dataclasses import dataclass -from datetime import UTC, datetime -from threading import Event - -import pytest - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from core.workflow.runtime import GraphRuntimeState, VariablePool -from models.enums import CreatorUserRole -from models.model import AppMode -from models.workflow import WorkflowRun -from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot -from repositories.entities.workflow_pause import WorkflowPauseEntity -from services.workflow_event_snapshot_service import ( - BufferState, - MessageContext, - _build_snapshot_events, - _resolve_task_id, -) - - -@dataclass(frozen=True) -class _FakePauseEntity(WorkflowPauseEntity): - pause_id: str - workflow_run_id: str - paused_at_value: datetime - pause_reasons: Sequence[HumanInputRequired] - - @property - def id(self) -> str: - return self.pause_id - - @property - def workflow_execution_id(self) -> str: - return self.workflow_run_id - - def get_state(self) -> bytes: - raise AssertionError("state is not required for snapshot tests") - - @property - def resumed_at(self) -> datetime | None: - return None - - @property - def paused_at(self) -> datetime: - return self.paused_at_value - - def get_pause_reasons(self) -> Sequence[HumanInputRequired]: - return self.pause_reasons - - -def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun: - return WorkflowRun( - id="run-1", - tenant_id="tenant-1", - app_id="app-1", - workflow_id="workflow-1", - type="workflow", - triggered_from="app-run", - version="v1", - graph=None, - inputs=json.dumps({"input": "value"}), - status=status, - outputs=json.dumps({}), - error=None, - elapsed_time=0.0, - total_tokens=0, - total_steps=0, - created_by_role=CreatorUserRole.END_USER, - created_by="user-1", - created_at=datetime(2024, 1, 1, tzinfo=UTC), - ) - - -def 
_build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot: - created_at = datetime(2024, 1, 1, tzinfo=UTC) - finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC) - return WorkflowNodeExecutionSnapshot( - execution_id="exec-1", - node_id="node-1", - node_type="human-input", - title="Human Input", - index=1, - status=status.value, - elapsed_time=0.5, - created_at=created_at, - finished_at=finished_at, - iteration_id=None, - loop_id=None, - ) - - -def _build_resumption_context(task_id: str) -> WorkflowResumptionContext: - app_config = WorkflowUIBasedAppConfig( - tenant_id="tenant-1", - app_id="app-1", - app_mode=AppMode.WORKFLOW, - workflow_id="workflow-1", - ) - generate_entity = WorkflowAppGenerateEntity( - task_id=task_id, - app_config=app_config, - inputs={}, - files=[], - user_id="user-1", - stream=True, - invoke_from=InvokeFrom.EXPLORE, - call_depth=0, - workflow_execution_id="run-1", - ) - runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) - runtime_state.register_paused_node("node-1") - runtime_state.outputs = {"result": "value"} - wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity) - return WorkflowResumptionContext( - generate_entity=wrapper, - serialized_graph_runtime_state=runtime_state.dumps(), - ) - - -def test_build_snapshot_events_includes_pause_event() -> None: - workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED) - snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED) - resumption_context = _build_resumption_context("task-ctx") - pause_entity = _FakePauseEntity( - pause_id="pause-1", - workflow_run_id="run-1", - paused_at_value=datetime(2024, 1, 1, tzinfo=UTC), - pause_reasons=[ - HumanInputRequired( - form_id="form-1", - form_content="content", - node_id="node-1", - node_title="Human Input", - ) - ], - ) - - events = _build_snapshot_events( - workflow_run=workflow_run, - node_snapshots=[snapshot], - task_id="task-ctx", - message_context=None, - pause_entity=pause_entity, - resumption_context=resumption_context, - ) - - assert [event["event"] for event in events] == [ - "workflow_started", - "node_started", - "node_finished", - "workflow_paused", - ] - assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value - pause_data = events[-1]["data"] - assert pause_data["paused_nodes"] == ["node-1"] - assert pause_data["outputs"] == {"result": "value"} - assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value - assert pause_data["created_at"] == int(workflow_run.created_at.timestamp()) - assert pause_data["elapsed_time"] == workflow_run.elapsed_time - assert pause_data["total_tokens"] == workflow_run.total_tokens - assert pause_data["total_steps"] == workflow_run.total_steps - - -def test_build_snapshot_events_applies_message_context() -> None: - workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING) - snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED) - message_context = MessageContext( - conversation_id="conv-1", - message_id="msg-1", - created_at=1700000000, - answer="snapshot message", - ) - - events = _build_snapshot_events( - workflow_run=workflow_run, - node_snapshots=[snapshot], - task_id="task-1", - message_context=message_context, - pause_entity=None, - resumption_context=None, - ) - - assert [event["event"] for event in events] == [ - "workflow_started", - "message_replace", - "node_started", - "node_finished", - ] - assert events[1]["answer"] == "snapshot message" - for event in events: - assert 
event["conversation_id"] == "conv-1" - assert event["message_id"] == "msg-1" - assert event["created_at"] == 1700000000 - - -@pytest.mark.parametrize( - ("context_task_id", "buffered_task_id", "expected"), - [ - ("task-ctx", "task-buffer", "task-ctx"), - (None, "task-buffer", "task-buffer"), - (None, None, "run-1"), - ], -) -def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -> None: - resumption_context = _build_resumption_context(context_task_id) if context_task_id else None - buffer_state = BufferState( - queue=queue.Queue(), - stop_event=Event(), - done_event=Event(), - task_id_ready=Event(), - task_id_hint=buffered_task_id, - ) - if buffered_task_id: - buffer_state.task_id_ready.set() - task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0) - assert task_id == expected diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py deleted file mode 100644 index 5ac5ac8ad2..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ /dev/null @@ -1,184 +0,0 @@ -import uuid -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest -from sqlalchemy.orm import sessionmaker - -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, - HumanInputNodeData, - MemberRecipient, -) -from services import workflow_service as workflow_service_module -from services.workflow_service import WorkflowService - - -def _make_service() -> WorkflowService: - return WorkflowService(session_maker=sessionmaker()) - - -def _build_node_config(delivery_methods): - node_data = HumanInputNodeData( - title="Human Input", - delivery_methods=delivery_methods, - form_content="Test content", - inputs=[], - user_actions=[], - ).model_dump(mode="json") - node_data["type"] = NodeType.HUMAN_INPUT.value - return {"id": "node-1", "data": node_data} - - -def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod: - return EmailDeliveryMethod( - id=uuid.uuid4(), - enabled=enabled, - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ExternalRecipient(email="tester@example.com")], - ), - subject="Test subject", - body="Test body", - debug_mode=debug_mode, - ), - ) - - -def test_human_input_delivery_requires_draft_workflow(): - service = _make_service() - service.get_draft_workflow = MagicMock(return_value=None) # type: ignore[method-assign] - app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") - account = SimpleNamespace(id="account-1") - - with pytest.raises(ValueError, match="Workflow not initialized"): - service.test_human_input_delivery( - app_model=app_model, - account=account, - node_id="node-1", - delivery_method_id="delivery-1", - ) - - -def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch): - service = _make_service() - delivery_method = _make_email_method(enabled=False) - node_config = _build_node_config([delivery_method]) - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = node_config - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] - node_stub = MagicMock() - 
node_stub._render_form_content_before_submission.return_value = "rendered" - node_stub._resolve_default_values.return_value = {} - service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined] - service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] - return_value=("form-1", {}) - ) - - test_service_instance = MagicMock() - monkeypatch.setattr( - workflow_service_module, - "HumanInputDeliveryTestService", - MagicMock(return_value=test_service_instance), - ) - - app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") - account = SimpleNamespace(id="account-1") - - service.test_human_input_delivery( - app_model=app_model, - account=account, - node_id="node-1", - delivery_method_id=str(delivery_method.id), - ) - - test_service_instance.send_test.assert_called_once() - - -def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch): - service = _make_service() - delivery_method = _make_email_method(enabled=True) - node_config = _build_node_config([delivery_method]) - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = node_config - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] - node_stub = MagicMock() - node_stub._render_form_content_before_submission.return_value = "rendered" - node_stub._resolve_default_values.return_value = {} - service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined] - service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] - return_value=("form-1", {}) - ) - - test_service_instance = MagicMock() - monkeypatch.setattr( - workflow_service_module, - "HumanInputDeliveryTestService", - MagicMock(return_value=test_service_instance), - ) - - app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") - account = SimpleNamespace(id="account-1") - - service.test_human_input_delivery( - app_model=app_model, - account=account, - node_id="node-1", - delivery_method_id=str(delivery_method.id), - inputs={"#node-1.output#": "value"}, - ) - - pool_args = service._build_human_input_variable_pool.call_args.kwargs - assert pool_args["manual_inputs"] == {"#node-1.output#": "value"} - test_service_instance.send_test.assert_called_once() - - -def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch): - service = _make_service() - delivery_method = _make_email_method(enabled=True, debug_mode=True) - node_config = _build_node_config([delivery_method]) - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = node_config - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] - node_stub = MagicMock() - node_stub._render_form_content_before_submission.return_value = "rendered" - node_stub._resolve_default_values.return_value = {} - service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined] - service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] - return_value=("form-1", {}) - ) - - test_service_instance = MagicMock() - monkeypatch.setattr( - workflow_service_module, - "HumanInputDeliveryTestService", - MagicMock(return_value=test_service_instance), - ) - - app_model = 
SimpleNamespace(tenant_id="tenant-1", id="app-1") - account = SimpleNamespace(id="account-1") - - service.test_human_input_delivery( - app_model=app_model, - account=account, - node_id="node-1", - delivery_method_id=str(delivery_method.id), - ) - - test_service_instance.send_test.assert_called_once() - sent_method = test_service_instance.send_test.call_args.kwargs["method"] - assert isinstance(sent_method, EmailDeliveryMethod) - assert sent_method.config.debug_mode is True - assert sent_method.config.recipients.whole_workspace is False - assert len(sent_method.config.recipients.items) == 1 - recipient = sent_method.config.recipients.items[0] - assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == account.id diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py index 70d7bde870..32d2f8b7e0 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -5,7 +5,6 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from core.workflow.enums import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel from repositories.sqlalchemy_api_workflow_node_execution_repository import ( DifyAPISQLAlchemyWorkflowNodeExecutionRepository, @@ -53,9 +52,6 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: call_args = mock_session.scalar.call_args[0][0] assert hasattr(call_args, "compile") # It's a SQLAlchemy statement - compiled = call_args.compile() - assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values() - def test_get_node_last_execution_not_found(self, repository): """Test getting the last execution for a node when it doesn't exist.""" # Arrange @@ -75,6 +71,28 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result is None mock_session.scalar.assert_called_once() + def test_get_executions_by_workflow_run(self, repository, mock_execution): + """Test getting all executions for a workflow run.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + executions = [mock_execution] + mock_session.execute.return_value.scalars.return_value.all.return_value = executions + + # Act + result = repository.get_executions_by_workflow_run( + tenant_id="tenant-123", + app_id="app-456", + workflow_run_id="run-101", + ) + + # Assert + assert result == executions + mock_session.execute.assert_called_once() + # Verify the query was constructed correctly + call_args = mock_session.execute.call_args[0][0] + assert hasattr(call_args, "compile") # It's a SQLAlchemy statement + def test_get_executions_by_workflow_run_empty(self, repository): """Test getting executions for a workflow run when none exist.""" # Arrange diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 015dac257e..9700cbaf0e 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -1,15 +1,9 @@ -from contextlib import nullcontext -from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities 
import FormInput, HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import FormInputType from models.model import App from models.workflow import Workflow -from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService @@ -167,120 +161,3 @@ class TestWorkflowService: assert workflows == [] assert has_more is False mock_session.scalars.assert_called_once() - - def test_submit_human_input_form_preview_uses_rendered_content( - self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch - ) -> None: - service = workflow_service - node_data = HumanInputNodeData( - title="Human Input", - form_content="
{{#$output.name#}}
", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - ) - node = MagicMock() - node.node_data = node_data - node.render_form_content_before_submission.return_value = "
preview
" - node.render_form_content_with_outputs.return_value = "
rendered
" - - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] - - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} - workflow.get_enclosing_node_type_and_id.return_value = None - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - - saved_outputs: dict[str, object] = {} - - class DummySession: - def __init__(self, *args, **kwargs): - self.commit = MagicMock() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return nullcontext() - - class DummySaver: - def __init__(self, *args, **kwargs): - pass - - def save(self, outputs, process_data): - saved_outputs.update(outputs) - - monkeypatch.setattr(workflow_service_module, "Session", DummySession) - monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver) - monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) - - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") - account = SimpleNamespace(id="account-1") - - result = service.submit_human_input_form_preview( - app_model=app_model, - account=account, - node_id="node-1", - form_inputs={"name": "Ada", "extra": "ignored"}, - inputs={"#node-0.result#": "LLM output"}, - action="approve", - ) - - service._build_human_input_variable_pool.assert_called_once_with( - app_model=app_model, - workflow=workflow, - node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}, - manual_inputs={"#node-0.result#": "LLM output"}, - ) - - node.render_form_content_with_outputs.assert_called_once() - called_args = node.render_form_content_with_outputs.call_args.args - assert called_args[0] == "
preview
" - assert called_args[2] == node_data.outputs_field_names() - rendered_outputs = called_args[1] - assert rendered_outputs["name"] == "Ada" - assert rendered_outputs["extra"] == "ignored" - assert "extra" in saved_outputs - assert "extra" in result - assert saved_outputs["name"] == "Ada" - assert result["name"] == "Ada" - assert result["__action_id"] == "approve" - assert "__rendered_content" in result - - def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None: - service = workflow_service - node_data = HumanInputNodeData( - title="Human Input", - form_content="
{{#$output.name#}}
", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - ) - node = MagicMock() - node.node_data = node_data - node._render_form_content_before_submission.return_value = "
preview
" - node._render_form_content_with_outputs.return_value = "
rendered
" - - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] - - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") - account = SimpleNamespace(id="account-1") - - with pytest.raises(ValueError) as exc_info: - service.submit_human_input_form_preview( - app_model=app_model, - account=account, - node_id="node-1", - form_inputs={}, - inputs={}, - action="approve", - ) - - assert "Missing required inputs" in str(exc_info.value) diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py deleted file mode 100644 index ee0699ba2d..0000000000 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ /dev/null @@ -1,210 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timedelta -from types import SimpleNamespace -from typing import Any - -import pytest - -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from tasks import human_input_timeout_tasks as task_module - - -class _FakeScalarResult: - def __init__(self, items: list[Any]): - self._items = items - - def all(self) -> list[Any]: - return self._items - - -class _FakeSession: - def __init__(self, items: list[Any], capture: dict[str, Any]): - self._items = items - self._capture = capture - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def scalars(self, stmt): - self._capture["stmt"] = stmt - return _FakeScalarResult(self._items) - - -class _FakeSessionFactory: - def __init__(self, items: list[Any], capture: dict[str, Any]): - self._items = items - self._capture = capture - self._capture["session_factory"] = self - - def __call__(self): - session = _FakeSession(self._items, self._capture) - self._capture["session"] = session - return session - - -class _FakeFormRepo: - def __init__(self, _session_factory, form_map: dict[str, Any] | None = None): - self.calls: list[dict[str, Any]] = [] - self._form_map = form_map or {} - - def mark_timeout(self, *, form_id: str, timeout_status: HumanInputFormStatus, reason: str | None = None): - self.calls.append( - { - "form_id": form_id, - "timeout_status": timeout_status, - "reason": reason, - } - ) - form = self._form_map.get(form_id) - return SimpleNamespace( - form_id=form_id, - workflow_run_id=getattr(form, "workflow_run_id", None), - node_id=getattr(form, "node_id", None), - ) - - -class _FakeService: - def __init__(self, _session_factory, form_repository=None): - self.enqueued: list[str] = [] - - def enqueue_resume(self, workflow_run_id: str | None) -> None: - if workflow_run_id is not None: - self.enqueued.append(workflow_run_id) - - -def _build_form( - *, - form_id: str, - form_kind: HumanInputFormKind, - created_at: datetime, - expiration_time: datetime, - workflow_run_id: str | None, - node_id: str, -) -> SimpleNamespace: - return SimpleNamespace( - id=form_id, - form_kind=form_kind, - created_at=created_at, - expiration_time=expiration_time, - workflow_run_id=workflow_run_id, - node_id=node_id, - status=HumanInputFormStatus.WAITING, - ) - - -def test_is_global_timeout_uses_created_at(): - now = 
datetime(2025, 1, 1, 12, 0, 0) - form = SimpleNamespace(created_at=now - timedelta(seconds=61), workflow_run_id="run-1") - - assert task_module._is_global_timeout(form, 60, now=now) is True - - form.workflow_run_id = None - assert task_module._is_global_timeout(form, 60, now=now) is False - - form.workflow_run_id = "run-1" - form.created_at = now - timedelta(seconds=59) - assert task_module._is_global_timeout(form, 60, now=now) is False - - assert task_module._is_global_timeout(form, 0, now=now) is False - - -def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pytest.MonkeyPatch): - now = datetime(2025, 1, 1, 12, 0, 0) - monkeypatch.setattr(task_module, "naive_utc_now", lambda: now) - monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600) - monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object())) - - forms = [ - _build_form( - form_id="form-global", - form_kind=HumanInputFormKind.RUNTIME, - created_at=now - timedelta(hours=2), - expiration_time=now + timedelta(hours=1), - workflow_run_id="run-global", - node_id="node-global", - ), - _build_form( - form_id="form-node", - form_kind=HumanInputFormKind.RUNTIME, - created_at=now - timedelta(minutes=5), - expiration_time=now - timedelta(seconds=1), - workflow_run_id="run-node", - node_id="node-node", - ), - _build_form( - form_id="form-delivery", - form_kind=HumanInputFormKind.DELIVERY_TEST, - created_at=now - timedelta(minutes=1), - expiration_time=now - timedelta(seconds=1), - workflow_run_id=None, - node_id="node-delivery", - ), - ] - - capture: dict[str, Any] = {} - monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture)) - - form_map = {form.id: form for form in forms} - repo = _FakeFormRepo(None, form_map=form_map) - - def _repo_factory(_session_factory): - return repo - - service = _FakeService(None) - - def _service_factory(_session_factory, form_repository=None): - return service - - global_calls: list[dict[str, Any]] = [] - - monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _repo_factory) - monkeypatch.setattr(task_module, "HumanInputService", _service_factory) - monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **kwargs: global_calls.append(kwargs)) - - task_module.check_and_handle_human_input_timeouts(limit=100) - - assert {(call["form_id"], call["timeout_status"], call["reason"]) for call in repo.calls} == { - ("form-global", HumanInputFormStatus.EXPIRED, "global_timeout"), - ("form-node", HumanInputFormStatus.TIMEOUT, "node_timeout"), - ("form-delivery", HumanInputFormStatus.TIMEOUT, "delivery_test_timeout"), - } - assert service.enqueued == ["run-node"] - assert global_calls == [ - { - "form_id": "form-global", - "workflow_run_id": "run-global", - "node_id": "node-global", - "session_factory": capture.get("session_factory"), - } - ] - - stmt = capture.get("stmt") - assert stmt is not None - stmt_text = str(stmt) - assert "created_at <=" in stmt_text - assert "expiration_time <=" in stmt_text - assert "ORDER BY human_input_forms.id" in stmt_text - - -def test_check_and_handle_human_input_timeouts_omits_global_filter_when_disabled(monkeypatch: pytest.MonkeyPatch): - now = datetime(2025, 1, 1, 12, 0, 0) - monkeypatch.setattr(task_module, "naive_utc_now", lambda: now) - monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 0) - monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object())) - - capture: dict[str, Any] = {} - 
monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory([], capture)) - monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _FakeFormRepo) - monkeypatch.setattr(task_module, "HumanInputService", _FakeService) - monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **_kwargs: None) - - task_module.check_and_handle_human_input_timeouts(limit=1) - - stmt = capture.get("stmt") - assert stmt is not None - stmt_text = str(stmt) - assert "created_at <=" not in stmt_text diff --git a/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py deleted file mode 100644 index 20cb7a211e..0000000000 --- a/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py +++ /dev/null @@ -1,123 +0,0 @@ -from collections.abc import Sequence -from types import SimpleNamespace - -import pytest - -from tasks import mail_human_input_delivery_task as task_module - - -class _DummyMail: - def __init__(self): - self.sent: list[dict[str, str]] = [] - self._inited = True - - def is_inited(self) -> bool: - return self._inited - - def send(self, *, to: str, subject: str, html: str): - self.sent.append({"to": to, "subject": subject, "html": html}) - - -class _DummySession: - def __init__(self, form): - self._form = form - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return False - - def get(self, _model, _form_id): - return self._form - - -def _build_job(recipient_count: int = 1) -> task_module._EmailDeliveryJob: - recipients: list[task_module._EmailRecipient] = [] - for idx in range(recipient_count): - recipients.append(task_module._EmailRecipient(email=f"user{idx}@example.com", token=f"token-{idx}")) - - return task_module._EmailDeliveryJob( - form_id="form-1", - subject="Subject", - body="Body for {{#url}}", - form_content="content", - recipients=recipients, - ) - - -def test_dispatch_human_input_email_task_sends_to_each_recipient(monkeypatch: pytest.MonkeyPatch): - mail = _DummyMail() - form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None) - - monkeypatch.setattr(task_module, "mail", mail) - monkeypatch.setattr( - task_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), - ) - jobs: Sequence[task_module._EmailDeliveryJob] = [_build_job(recipient_count=2)] - monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: jobs) - - task_module.dispatch_human_input_email_task( - form_id="form-1", - node_title="Approve", - session_factory=lambda: _DummySession(form), - ) - - assert len(mail.sent) == 2 - assert all(payload["subject"] == "Subject" for payload in mail.sent) - assert all("Body for" in payload["html"] for payload in mail.sent) - - -def test_dispatch_human_input_email_task_skips_when_feature_disabled(monkeypatch: pytest.MonkeyPatch): - mail = _DummyMail() - form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None) - - monkeypatch.setattr(task_module, "mail", mail) - monkeypatch.setattr( - task_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), - ) - monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: []) - - task_module.dispatch_human_input_email_task( - form_id="form-1", - node_title="Approve", - session_factory=lambda: _DummySession(form), - ) - - assert mail.sent == [] - - -def 
test_dispatch_human_input_email_task_replaces_body_variables(monkeypatch: pytest.MonkeyPatch): - mail = _DummyMail() - form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id="run-1") - job = task_module._EmailDeliveryJob( - form_id="form-1", - subject="Subject", - body="Body {{#node1.value#}}", - form_content="content", - recipients=[task_module._EmailRecipient(email="user@example.com", token="token-1")], - ) - - variable_pool = task_module.VariablePool() - variable_pool.add(["node1", "value"], "OK") - - monkeypatch.setattr(task_module, "mail", mail) - monkeypatch.setattr( - task_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), - ) - monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [job]) - monkeypatch.setattr(task_module, "_load_variable_pool", lambda _workflow_run_id: variable_pool) - - task_module.dispatch_human_input_email_task( - form_id="form-1", - node_title="Approve", - session_factory=lambda: _DummySession(form), - ) - - assert mail.sent[0]["html"] == "Body OK" diff --git a/api/tests/unit_tests/tasks/test_workflow_execute_task.py b/api/tests/unit_tests/tasks/test_workflow_execute_task.py deleted file mode 100644 index 161151305d..0000000000 --- a/api/tests/unit_tests/tasks/test_workflow_execute_task.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -import json -import uuid -from unittest.mock import MagicMock - -import pytest - -from models.model import AppMode -from tasks.app_generate.workflow_execute_task import _publish_streaming_response - - -@pytest.fixture -def mock_topic(mocker) -> MagicMock: - topic = MagicMock() - mocker.patch( - "tasks.app_generate.workflow_execute_task.MessageBasedAppGenerator.get_response_topic", - return_value=topic, - ) - return topic - - -def test_publish_streaming_response_with_uuid(mock_topic: MagicMock): - workflow_run_id = uuid.uuid4() - response_stream = iter([{"event": "foo"}, "ping"]) - - _publish_streaming_response(response_stream, workflow_run_id, app_mode=AppMode.ADVANCED_CHAT) - - payloads = [call.args[0] for call in mock_topic.publish.call_args_list] - assert payloads == [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()] - - -def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock): - workflow_run_id = uuid.uuid4() - response_stream = iter([{"event": "bar"}]) - - _publish_streaming_response(response_stream, str(workflow_run_id), app_mode=AppMode.ADVANCED_CHAT) - - mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode()) diff --git a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py deleted file mode 100644 index fd5f0713a4..0000000000 --- a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py +++ /dev/null @@ -1,488 +0,0 @@ -# """ -# Unit tests for workflow node execution Celery tasks. - -# These tests verify the asynchronous storage functionality for workflow node execution data, -# including truncation and offloading logic. 
-# """ - -# import json -# from unittest.mock import MagicMock, Mock, patch -# from uuid import uuid4 - -# import pytest - -# from core.workflow.entities.workflow_node_execution import ( -# WorkflowNodeExecution, -# WorkflowNodeExecutionStatus, -# ) -# from core.workflow.enums import NodeType -# from libs.datetime_utils import naive_utc_now -# from models import WorkflowNodeExecutionModel -# from models.enums import ExecutionOffLoadType -# from models.model import UploadFile -# from models.workflow import WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom -# from tasks.workflow_node_execution_tasks import ( -# _create_truncator, -# _json_encode, -# _replace_or_append_offload, -# _truncate_and_upload_async, -# save_workflow_node_execution_data_task, -# save_workflow_node_execution_task, -# ) - - -# @pytest.fixture -# def sample_execution_data(): -# """Sample execution data for testing.""" -# execution = WorkflowNodeExecution( -# id=str(uuid4()), -# node_execution_id=str(uuid4()), -# workflow_id=str(uuid4()), -# workflow_execution_id=str(uuid4()), -# index=1, -# node_id="test_node", -# node_type=NodeType.LLM, -# title="Test Node", -# inputs={"input_key": "input_value"}, -# outputs={"output_key": "output_value"}, -# process_data={"process_key": "process_value"}, -# status=WorkflowNodeExecutionStatus.RUNNING, -# created_at=naive_utc_now(), -# ) -# return execution.model_dump() - - -# @pytest.fixture -# def mock_db_model(): -# """Mock database model for testing.""" -# db_model = Mock(spec=WorkflowNodeExecutionModel) -# db_model.id = "test-execution-id" -# db_model.offload_data = [] -# return db_model - - -# @pytest.fixture -# def mock_file_service(): -# """Mock file service for testing.""" -# file_service = Mock() -# mock_upload_file = Mock(spec=UploadFile) -# mock_upload_file.id = "mock-file-id" -# file_service.upload_file.return_value = mock_upload_file -# return file_service - - -# class TestSaveWorkflowNodeExecutionDataTask: -# """Test cases for save_workflow_node_execution_data_task.""" - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# @patch("tasks.workflow_node_execution_tasks.select") -# def test_save_execution_data_task_success( -# self, mock_select, mock_sessionmaker, sample_execution_data, mock_db_model -# ): -# """Test successful execution of save_workflow_node_execution_data_task.""" -# # Setup mocks -# mock_session = MagicMock() -# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session -# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model - -# # Execute task -# result = save_workflow_node_execution_data_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# user_data={"user_id": "test-user-id", "user_type": "account"}, -# ) - -# # Verify success -# assert result is True -# mock_session.merge.assert_called_once_with(mock_db_model) -# mock_session.commit.assert_called_once() - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# @patch("tasks.workflow_node_execution_tasks.select") -# def test_save_execution_data_task_execution_not_found(self, mock_select, mock_sessionmaker, -# sample_execution_data): -# """Test task when execution is not found in database.""" -# # Setup mocks -# mock_session = MagicMock() -# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session -# mock_session.execute.return_value.scalars.return_value.first.return_value = None - -# # Execute task -# result = 
save_workflow_node_execution_data_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# user_data={"user_id": "test-user-id", "user_type": "account"}, -# ) - -# # Verify failure -# assert result is False -# mock_session.merge.assert_not_called() -# mock_session.commit.assert_not_called() - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# @patch("tasks.workflow_node_execution_tasks.select") -# def test_save_execution_data_task_with_truncation(self, mock_select, mock_sessionmaker, mock_db_model): -# """Test task with data that requires truncation.""" -# # Create execution with large data -# large_data = {"large_field": "x" * 10000} -# execution = WorkflowNodeExecution( -# id=str(uuid4()), -# node_execution_id=str(uuid4()), -# workflow_id=str(uuid4()), -# workflow_execution_id=str(uuid4()), -# index=1, -# node_id="test_node", -# node_type=NodeType.LLM, -# title="Test Node", -# inputs=large_data, -# outputs=large_data, -# process_data=large_data, -# status=WorkflowNodeExecutionStatus.RUNNING, -# created_at=naive_utc_now(), -# ) -# execution_data = execution.model_dump() - -# # Setup mocks -# mock_session = MagicMock() -# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session -# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model - -# # Create mock upload file -# mock_upload_file = Mock(spec=UploadFile) -# mock_upload_file.id = "mock-file-id" - -# # Execute task -# with patch("tasks.workflow_node_execution_tasks._truncate_and_upload_async") as mock_truncate: -# # Mock truncation results -# mock_truncate.return_value = { -# "truncated_value": {"large_field": "[TRUNCATED]"}, -# "file": mock_upload_file, -# "offload": WorkflowNodeExecutionOffload( -# id=str(uuid4()), -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# node_execution_id=execution.id, -# type_=ExecutionOffLoadType.INPUTS, -# file_id=mock_upload_file.id, -# ), -# } - -# result = save_workflow_node_execution_data_task( -# execution_data=execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# user_data={"user_id": "test-user-id", "user_type": "account"}, -# ) - -# # Verify success and truncation was called -# assert result is True -# assert mock_truncate.call_count == 3 # inputs, outputs, process_data -# mock_session.merge.assert_called_once_with(mock_db_model) -# mock_session.commit.assert_called_once() - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# def test_save_execution_data_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data): -# """Test task retry mechanism on exception.""" -# # Setup mock to raise exception -# mock_sessionmaker.side_effect = Exception("Database error") - -# # Create a mock task instance with proper retry behavior -# with patch.object(save_workflow_node_execution_data_task, "retry") as mock_retry: -# mock_retry.side_effect = Exception("Retry called") - -# # Execute task and expect retry -# with pytest.raises(Exception, match="Retry called"): -# save_workflow_node_execution_data_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# user_data={"user_id": "test-user-id", "user_type": "account"}, -# ) - -# # Verify retry was called -# mock_retry.assert_called_once() - - -# class TestTruncateAndUploadAsync: -# """Test cases for _truncate_and_upload_async function.""" - -# def test_truncate_and_upload_with_none_values(self, mock_file_service): -# """Test 
_truncate_and_upload_async with None values.""" -# # The function handles None values internally, so we test with empty dict instead -# result = _truncate_and_upload_async( -# values={}, -# execution_id="test-id", -# type_=ExecutionOffLoadType.INPUTS, -# tenant_id="test-tenant", -# app_id="test-app", -# user_data={"user_id": "test-user", "user_type": "account"}, -# file_service=mock_file_service, -# ) - -# # Empty dict should not require truncation -# assert result is None -# mock_file_service.upload_file.assert_not_called() - -# @patch("tasks.workflow_node_execution_tasks._create_truncator") -# def test_truncate_and_upload_no_truncation_needed(self, mock_create_truncator, mock_file_service): -# """Test _truncate_and_upload_async when no truncation is needed.""" -# # Mock truncator to return no truncation -# mock_truncator = Mock() -# mock_truncator.truncate_variable_mapping.return_value = ({"small": "data"}, False) -# mock_create_truncator.return_value = mock_truncator - -# small_values = {"small": "data"} -# result = _truncate_and_upload_async( -# values=small_values, -# execution_id="test-id", -# type_=ExecutionOffLoadType.INPUTS, -# tenant_id="test-tenant", -# app_id="test-app", -# user_data={"user_id": "test-user", "user_type": "account"}, -# file_service=mock_file_service, -# ) - -# assert result is None -# mock_file_service.upload_file.assert_not_called() - -# @patch("tasks.workflow_node_execution_tasks._create_truncator") -# @patch("models.Account") -# @patch("models.Tenant") -# def test_truncate_and_upload_with_account_user( -# self, mock_tenant_class, mock_account_class, mock_create_truncator, mock_file_service -# ): -# """Test _truncate_and_upload_async with account user.""" -# # Mock truncator to return truncation needed -# mock_truncator = Mock() -# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True) -# mock_create_truncator.return_value = mock_truncator - -# # Mock user and tenant creation -# mock_account = Mock() -# mock_account.id = "test-user" -# mock_account_class.return_value = mock_account - -# mock_tenant = Mock() -# mock_tenant.id = "test-tenant" -# mock_tenant_class.return_value = mock_tenant - -# large_values = {"large": "x" * 10000} -# result = _truncate_and_upload_async( -# values=large_values, -# execution_id="test-id", -# type_=ExecutionOffLoadType.INPUTS, -# tenant_id="test-tenant", -# app_id="test-app", -# user_data={"user_id": "test-user", "user_type": "account"}, -# file_service=mock_file_service, -# ) - -# # Verify result structure -# assert result is not None -# assert "truncated_value" in result -# assert "file" in result -# assert "offload" in result -# assert result["truncated_value"] == {"truncated": "data"} - -# # Verify file upload was called -# mock_file_service.upload_file.assert_called_once() -# upload_call = mock_file_service.upload_file.call_args -# assert upload_call[1]["filename"] == "node_execution_test-id_inputs.json" -# assert upload_call[1]["mimetype"] == "application/json" -# assert upload_call[1]["user"] == mock_account - -# @patch("tasks.workflow_node_execution_tasks._create_truncator") -# @patch("models.EndUser") -# def test_truncate_and_upload_with_end_user(self, mock_end_user_class, mock_create_truncator, mock_file_service): -# """Test _truncate_and_upload_async with end user.""" -# # Mock truncator to return truncation needed -# mock_truncator = Mock() -# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True) -# mock_create_truncator.return_value = mock_truncator - -# 
# Mock end user creation -# mock_end_user = Mock() -# mock_end_user.id = "test-user" -# mock_end_user.tenant_id = "test-tenant" -# mock_end_user_class.return_value = mock_end_user - -# large_values = {"large": "x" * 10000} -# result = _truncate_and_upload_async( -# values=large_values, -# execution_id="test-id", -# type_=ExecutionOffLoadType.OUTPUTS, -# tenant_id="test-tenant", -# app_id="test-app", -# user_data={"user_id": "test-user", "user_type": "end_user"}, -# file_service=mock_file_service, -# ) - -# # Verify result structure -# assert result is not None -# assert result["truncated_value"] == {"truncated": "data"} - -# # Verify file upload was called with end user -# mock_file_service.upload_file.assert_called_once() -# upload_call = mock_file_service.upload_file.call_args -# assert upload_call[1]["filename"] == "node_execution_test-id_outputs.json" -# assert upload_call[1]["user"] == mock_end_user - - -# class TestHelperFunctions: -# """Test cases for helper functions.""" - -# @patch("tasks.workflow_node_execution_tasks.dify_config") -# def test_create_truncator(self, mock_config): -# """Test _create_truncator function.""" -# mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000 -# mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100 -# mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500 - -# truncator = _create_truncator() - -# # Verify truncator was created with correct config -# assert truncator is not None - -# def test_json_encode(self): -# """Test _json_encode function.""" -# test_data = {"key": "value", "number": 42} -# result = _json_encode(test_data) - -# assert isinstance(result, str) -# decoded = json.loads(result) -# assert decoded == test_data - -# def test_replace_or_append_offload_replace_existing(self): -# """Test _replace_or_append_offload replaces existing offload of same type.""" -# existing_offload = WorkflowNodeExecutionOffload( -# id=str(uuid4()), -# tenant_id="test-tenant", -# app_id="test-app", -# node_execution_id="test-execution", -# type_=ExecutionOffLoadType.INPUTS, -# file_id="old-file-id", -# ) - -# new_offload = WorkflowNodeExecutionOffload( -# id=str(uuid4()), -# tenant_id="test-tenant", -# app_id="test-app", -# node_execution_id="test-execution", -# type_=ExecutionOffLoadType.INPUTS, -# file_id="new-file-id", -# ) - -# result = _replace_or_append_offload([existing_offload], new_offload) - -# assert len(result) == 1 -# assert result[0].file_id == "new-file-id" - -# def test_replace_or_append_offload_append_new_type(self): -# """Test _replace_or_append_offload appends new offload of different type.""" -# existing_offload = WorkflowNodeExecutionOffload( -# id=str(uuid4()), -# tenant_id="test-tenant", -# app_id="test-app", -# node_execution_id="test-execution", -# type_=ExecutionOffLoadType.INPUTS, -# file_id="inputs-file-id", -# ) - -# new_offload = WorkflowNodeExecutionOffload( -# id=str(uuid4()), -# tenant_id="test-tenant", -# app_id="test-app", -# node_execution_id="test-execution", -# type_=ExecutionOffLoadType.OUTPUTS, -# file_id="outputs-file-id", -# ) - -# result = _replace_or_append_offload([existing_offload], new_offload) - -# assert len(result) == 2 -# file_ids = [offload.file_id for offload in result] -# assert "inputs-file-id" in file_ids -# assert "outputs-file-id" in file_ids - - -# class TestSaveWorkflowNodeExecutionTask: -# """Test cases for save_workflow_node_execution_task.""" - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# @patch("tasks.workflow_node_execution_tasks.select") -# def 
test_save_workflow_node_execution_task_create_new(self, mock_select, mock_sessionmaker, -# sample_execution_data): -# """Test creating a new workflow node execution.""" -# # Setup mocks -# mock_session = MagicMock() -# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session -# mock_session.scalar.return_value = None # No existing execution - -# # Execute task -# result = save_workflow_node_execution_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, -# creator_user_id="test-user-id", -# creator_user_role="account", -# ) - -# # Verify success -# assert result is True -# mock_session.add.assert_called_once() -# mock_session.commit.assert_called_once() - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# @patch("tasks.workflow_node_execution_tasks.select") -# def test_save_workflow_node_execution_task_update_existing( -# self, mock_select, mock_sessionmaker, sample_execution_data -# ): -# """Test updating an existing workflow node execution.""" -# # Setup mocks -# mock_session = MagicMock() -# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session - -# existing_execution = Mock(spec=WorkflowNodeExecutionModel) -# mock_session.scalar.return_value = existing_execution - -# # Execute task -# result = save_workflow_node_execution_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, -# creator_user_id="test-user-id", -# creator_user_role="account", -# ) - -# # Verify success -# assert result is True -# mock_session.add.assert_not_called() # Should not add new, just update existing -# mock_session.commit.assert_called_once() - -# @patch("tasks.workflow_node_execution_tasks.sessionmaker") -# def test_save_workflow_node_execution_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data): -# """Test task retry mechanism on exception.""" -# # Setup mock to raise exception -# mock_sessionmaker.side_effect = Exception("Database error") - -# # Create a mock task instance with proper retry behavior -# with patch.object(save_workflow_node_execution_task, "retry") as mock_retry: -# mock_retry.side_effect = Exception("Retry called") - -# # Execute task and expect retry -# with pytest.raises(Exception, match="Retry called"): -# save_workflow_node_execution_task( -# execution_data=sample_execution_data, -# tenant_id="test-tenant-id", -# app_id="test-app-id", -# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, -# creator_user_id="test-user-id", -# creator_user_role="account", -# ) - -# # Verify retry was called -# mock_retry.assert_called_once() diff --git a/api/ty.toml b/api/ty.toml index 6869ca98c4..afdd37897e 100644 --- a/api/ty.toml +++ b/api/ty.toml @@ -26,20 +26,5 @@ exclude = [ # non-producition or generated code "migrations", "tests", - # targeted ignores for current type-check errors - # TODO(QuantumGhost): suppress type errors in HITL related code. 
- # fix the type error later - "configs/middleware/cache/redis_pubsub_config.py", - "extensions/ext_redis.py", - "models/execution_extra_content.py", - "tasks/workflow_execution_tasks.py", - "core/workflow/nodes/base/node.py", - "services/human_input_delivery_test_service.py", - "core/app/apps/advanced_chat/app_generator.py", - "controllers/console/human_input_form.py", - "controllers/console/app/workflow_run.py", - "repositories/sqlalchemy_api_workflow_node_execution_repository.py", - "extensions/logstore/repositories/logstore_api_workflow_run_repository.py", - "controllers/web/workflow_events.py", - "tasks/app_generate/workflow_execute_task.py", ] + diff --git a/docker/.env.example b/docker/.env.example index 93099347bd..41a0205bf5 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1399,9 +1399,9 @@ PLUGIN_STDIO_BUFFER_SIZE=1024 PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880 PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 -# Plugin Daemon side timeout (configure to match the API side below) +# Plugin Daemon side timeout (configure to match the API side below) PLUGIN_MAX_EXECUTION_TIMEOUT=600 -# API side timeout (configure to match the Plugin Daemon side above) +# API side timeout (configure to match the Plugin Daemon side above) PLUGIN_DAEMON_TIMEOUT=600.0 # PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple PIP_MIRROR_URL= @@ -1519,31 +1519,4 @@ AMPLITUDE_API_KEY= SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 - - -# Redis URL used for PubSub between API and -# celery worker -# defaults to url constructed from `REDIS_*` -# configurations -PUBSUB_REDIS_URL= -# Pub/sub channel type for streaming events. -# valid options are: -# -# - pubsub: for normal Pub/Sub -# - sharded: for sharded Pub/Sub -# -# It's highly recommended to use sharded Pub/Sub AND redis cluster -# for large deployments. -PUBSUB_REDIS_CHANNEL_TYPE=pubsub -# Whether to use Redis cluster mode while running -# PubSub. -# It's highly recommended to enable this for large deployments. 
-PUBSUB_REDIS_USE_CLUSTERS=false - -# Whether to Enable human input timeout check task -ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true -# Human input timeout check interval in minutes -HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1 - - SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index a5518ceee9..a0a755f570 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -683,11 +683,6 @@ x-shared-env: &shared-api-worker-env SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21} SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000} SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30} - PUBSUB_REDIS_URL: ${PUBSUB_REDIS_URL:-} - PUBSUB_REDIS_CHANNEL_TYPE: ${PUBSUB_REDIS_CHANNEL_TYPE:-pubsub} - PUBSUB_REDIS_USE_CLUSTERS: ${PUBSUB_REDIS_USE_CLUSTERS:-false} - ENABLE_HUMAN_INPUT_TIMEOUT_TASK: ${ENABLE_HUMAN_INPUT_TIMEOUT_TASK:-true} - HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: ${HUMAN_INPUT_TIMEOUT_TASK_INTERVAL:-1} SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: ${SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL:-90000} services: diff --git a/web/__mocks__/provider-context.ts b/web/__mocks__/provider-context.ts index d3296bacd0..373c2f86d3 100644 --- a/web/__mocks__/provider-context.ts +++ b/web/__mocks__/provider-context.ts @@ -35,7 +35,6 @@ export const baseProviderContextValue: ProviderContextState = { refreshLicenseLimit: noop, isAllowTransferWorkspace: false, isAllowPublishAsCustomKnowledgePipelineTemplate: false, - humanInputEmailDeliveryEnabled: false, } export const createMockProviderContextValue = (overrides: Partial = {}): ProviderContextState => { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx index fffc1ff2a5..fc27f84c60 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx @@ -8,6 +8,7 @@ describe('SVG Attribute Error Reproduction', () => { // Capture console errors const originalError = console.error let errorMessages: string[] = [] + beforeEach(() => { errorMessages = [] console.error = vi.fn((message) => { diff --git a/web/app/(humanInputLayout)/form/[token]/form.tsx b/web/app/(humanInputLayout)/form/[token]/form.tsx deleted file mode 100644 index d027ef8b7d..0000000000 --- a/web/app/(humanInputLayout)/form/[token]/form.tsx +++ /dev/null @@ -1,289 +0,0 @@ -'use client' -import type { ButtonProps } from '@/app/components/base/button' -import type { FormInputItem, UserAction } from '@/app/components/workflow/nodes/human-input/types' -import type { SiteInfo } from '@/models/share' -import type { HumanInputFormError } from '@/service/use-share' -import { - RiCheckboxCircleFill, - RiErrorWarningFill, - RiInformation2Fill, -} from '@remixicon/react' -import { produce } from 'immer' -import { useParams } from 'next/navigation' -import * as React from 'react' -import { useEffect, useMemo, useState } from 'react' -import { useTranslation } from 'react-i18next' -import AppIcon from '@/app/components/base/app-icon' -import Button from '@/app/components/base/button' -import ContentItem from 
'@/app/components/base/chat/chat/answer/human-input-content/content-item' -import ExpirationTime from '@/app/components/base/chat/chat/answer/human-input-content/expiration-time' -import { getButtonStyle } from '@/app/components/base/chat/chat/answer/human-input-content/utils' -import Loading from '@/app/components/base/loading' -import DifyLogo from '@/app/components/base/logo/dify-logo' -import useDocumentTitle from '@/hooks/use-document-title' -import { useGetHumanInputForm, useSubmitHumanInputForm } from '@/service/use-share' -import { cn } from '@/utils/classnames' - -export type FormData = { - site: { site: SiteInfo } - form_content: string - inputs: FormInputItem[] - resolved_default_values: Record - user_actions: UserAction[] - expiration_time: number -} - -const FormContent = () => { - const { t } = useTranslation() - - const { token } = useParams<{ token: string }>() - useDocumentTitle('') - - const [inputs, setInputs] = useState>({}) - const [success, setSuccess] = useState(false) - - const { mutate: submitForm, isPending: isSubmitting } = useSubmitHumanInputForm() - - const { data: formData, isLoading, error } = useGetHumanInputForm(token) - - const expired = (error as HumanInputFormError | null)?.code === 'human_input_form_expired' - const submitted = (error as HumanInputFormError | null)?.code === 'human_input_form_submitted' - const rateLimitExceeded = (error as HumanInputFormError | null)?.code === 'web_form_rate_limit_exceeded' - - const splitByOutputVar = (content: string): string[] => { - const outputVarRegex = /(\{\{#\$output\.[^#]+#\}\})/g - const parts = content.split(outputVarRegex) - return parts.filter(part => part.length > 0) - } - - const contentList = useMemo(() => { - if (!formData?.form_content) - return [] - return splitByOutputVar(formData.form_content) - }, [formData?.form_content]) - - useEffect(() => { - if (!formData?.inputs) - return - const initialInputs: Record = {} - formData.inputs.forEach((item) => { - initialInputs[item.output_variable_name] = item.default.type === 'variable' ? formData.resolved_default_values[item.output_variable_name] || '' : item.default.value - }) - setInputs(initialInputs) - }, [formData?.inputs, formData?.resolved_default_values]) - - // use immer - const handleInputsChange = (name: string, value: string) => { - const newInputs = produce(inputs, (draft) => { - draft[name] = value - }) - setInputs(newInputs) - } - - const submit = (actionID: string) => { - submitForm( - { token, data: { inputs, action: actionID } }, - { - onSuccess: () => { - setSuccess(true) - }, - }, - ) - } - - if (isLoading) { - return ( - - ) - } - - if (success) { - return ( -
{t('humanInput.thanks', { ns: 'share' })} - {t('humanInput.recorded', { ns: 'share' })} - {t('humanInput.submissionID', { id: token, ns: 'share' })} - {t('chat.poweredBy', { ns: 'share' })} - <DifyLogo />
- ) - } - - if (expired) { - return ( -
{t('humanInput.sorry', { ns: 'share' })} - {t('humanInput.expired', { ns: 'share' })} - {t('humanInput.submissionID', { id: token, ns: 'share' })} - {t('chat.poweredBy', { ns: 'share' })} - <DifyLogo />
- ) - } - - if (submitted) { - return ( -
{t('humanInput.sorry', { ns: 'share' })} - {t('humanInput.completed', { ns: 'share' })} - {t('humanInput.submissionID', { id: token, ns: 'share' })} - {t('chat.poweredBy', { ns: 'share' })} - <DifyLogo />
- ) - } - - if (rateLimitExceeded) { - return ( -
{t('humanInput.rateLimitExceeded', { ns: 'share' })} - {t('chat.poweredBy', { ns: 'share' })} - <DifyLogo />
- ) - } - - if (!formData) { - return ( -
{t('humanInput.formNotFound', { ns: 'share' })} - {t('chat.poweredBy', { ns: 'share' })} - <DifyLogo />
- ) - } - - const site = formData.site.site - - return ( -
<AppIcon /> - {site.title} - {contentList.map((content, index) => ( - <ContentItem key={index} content={content} /> - ))} - {formData.user_actions.map((action: UserAction) => ( - <Button key={action.id} loading={isSubmitting} onClick={() => submit(action.id)}>{action.title}</Button> - ))} - <ExpirationTime /> - {t('chat.poweredBy', { ns: 'share' })} - <DifyLogo />
- ) -} - -export default React.memo(FormContent) diff --git a/web/app/(humanInputLayout)/form/[token]/page.tsx b/web/app/(humanInputLayout)/form/[token]/page.tsx deleted file mode 100644 index a7e2305b2b..0000000000 --- a/web/app/(humanInputLayout)/form/[token]/page.tsx +++ /dev/null @@ -1,13 +0,0 @@ -'use client' -import * as React from 'react' -import FormContent from './form' - -const FormPage = () => { - return ( -
<div> - <FormContent /> - </div>
- ) -} - -export default React.memo(FormPage) diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx index c874990448..113f3b5680 100644 --- a/web/app/(shareLayout)/components/authenticated-layout.tsx +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -47,7 +47,7 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => { await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router, shareCode]) + }, [getSigninUrl, router, webAppLogout, shareCode]) if (appInfoError) { return ( diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index a2b847f74f..9f89a03993 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -31,7 +31,7 @@ const Splash: FC = ({ children }) => { await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router, shareCode]) + }, [getSigninUrl, router, webAppLogout, shareCode]) const [isLoading, setIsLoading] = useState(true) useEffect(() => { diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 1348e3111f..0fc364cb7e 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -115,7 +115,6 @@ export type AppPublisherProps = { missingStartNode?: boolean hasTriggerNode?: boolean // Whether workflow currently contains any trigger nodes (used to hide missing-start CTA when triggers exist). startNodeLimitExceeded?: boolean - hasHumanInputNode?: boolean } const PUBLISH_SHORTCUT = ['ctrl', '⇧', 'P'] @@ -139,14 +138,13 @@ const AppPublisher = ({ missingStartNode = false, hasTriggerNode = false, startNodeLimitExceeded = false, - hasHumanInputNode = false, }: AppPublisherProps) => { const { t } = useTranslation() const [published, setPublished] = useState(false) const [open, setOpen] = useState(false) const [showAppAccessControl, setShowAppAccessControl] = useState(false) - + const [isAppAccessSet, setIsAppAccessSet] = useState(true) const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false) const appDetail = useAppStore(state => state.appDetail) @@ -163,13 +161,6 @@ const AppPublisher = ({ const { data: appAccessSubjects, isLoading: isGettingAppWhiteListSubjects } = useAppWhiteListSubjects(appDetail?.id, open && systemFeatures.webapp_auth.enabled && appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS) const openAsyncWindow = useAsyncWindowOpen() - const isAppAccessSet = useMemo(() => { - if (appDetail && appAccessSubjects) { - return !(appDetail.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && appAccessSubjects.groups?.length === 0 && appAccessSubjects.members?.length === 0) - } - return true - }, [appAccessSubjects, appDetail]) - const noAccessPermission = useMemo(() => systemFeatures.webapp_auth.enabled && appDetail && appDetail.access_mode !== AccessMode.EXTERNAL_MEMBERS && !userCanAccessApp?.result, [systemFeatures, appDetail, userCanAccessApp]) const disabledFunctionButton = useMemo(() => (!publishedAt || missingStartNode || noAccessPermission), [publishedAt, missingStartNode, noAccessPermission]) @@ -180,13 +171,25 @@ const AppPublisher = ({ return t('noUserInputNode', { ns: 'app' }) if (noAccessPermission) return t('noAccessPermission', { ns: 'app' }) - }, [missingStartNode, noAccessPermission, publishedAt, t]) + }, [missingStartNode, 
noAccessPermission, publishedAt]) useEffect(() => { if (systemFeatures.webapp_auth.enabled && open && appDetail) refetch() }, [open, appDetail, refetch, systemFeatures]) + useEffect(() => { + if (appDetail && appAccessSubjects) { + if (appDetail.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && appAccessSubjects.groups?.length === 0 && appAccessSubjects.members?.length === 0) + setIsAppAccessSet(false) + else + setIsAppAccessSet(true) + } + else { + setIsAppAccessSet(true) + } + }, [appAccessSubjects, appDetail]) + const handlePublish = useCallback(async (params?: ModelAndParameter | PublishWorkflowParams) => { try { await onPublish?.(params) @@ -458,7 +461,7 @@ const AppPublisher = ({ {t('common.accessAPIReference', { ns: 'workflow' })} - {appDetail?.mode === AppModeEnum.WORKFLOW && !hasHumanInputNode && ( + {appDetail?.mode === AppModeEnum.WORKFLOW && ( { if (!statusCount) return null - if (statusCount.paused > 0) { - return ( -
- - Pending -
- ) - } - else if (statusCount.partial_success + statusCount.failed === 0) { + if (statusCount.partial_success + statusCount.failed === 0) { return (
@@ -305,7 +296,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { if (abortControllerRef.current === controller) abortControllerRef.current = null } - }, [detail.id, hasMore, timezone, t, appDetail]) + }, [detail.id, hasMore, timezone, t, appDetail, detail?.model_config?.configs?.introduction]) // Derive chatItemTree, threadChatItems, and oldestAnswerIdRef from allChatItems useEffect(() => { @@ -420,7 +411,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) return false } - }, [allChatItems, appDetail?.id, notify, t]) + }, [allChatItems, appDetail?.id, t]) const fetchInitiated = useRef(false) @@ -513,7 +504,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { finally { setIsLoading(false) } - }, [detail.id, hasMore, isLoading, timezone, t, appDetail]) + }, [detail.id, hasMore, isLoading, timezone, t, appDetail, detail?.model_config?.configs?.introduction]) const handleScroll = useCallback(() => { const scrollableDiv = document.getElementById('scrollableDiv') diff --git a/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx index 54763907df..17857ec702 100644 --- a/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx +++ b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx @@ -53,7 +53,6 @@ const defaultProviderContext = { refreshLicenseLimit: noop, isAllowTransferWorkspace: false, isAllowPublishAsCustomKnowledgePipelineTemplate: false, - humanInputEmailDeliveryEnabled: false, } const defaultModalContext: ModalContextState = { diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx index 22358805a7..c39282a022 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ b/web/app/components/app/text-generate/item/index.tsx @@ -8,7 +8,7 @@ import { RiClipboardLine, RiFileList3Line, RiPlayList2Line, - RiResetLeftLine, + RiReplay15Line, RiSparklingFill, RiSparklingLine, RiThumbDownLine, @@ -18,12 +18,10 @@ import { useBoolean } from 'ahooks' import copy from 'copy-to-clipboard' import { useParams } from 'next/navigation' import * as React from 'react' -import { useCallback, useEffect, useState } from 'react' +import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { useStore as useAppStore } from '@/app/components/app/store' import ActionButton, { ActionButtonState } from '@/app/components/base/action-button' -import HumanInputFilledFormList from '@/app/components/base/chat/chat/answer/human-input-filled-form-list' -import HumanInputFormList from '@/app/components/base/chat/chat/answer/human-input-form-list' import WorkflowProcessItem from '@/app/components/base/chat/chat/answer/workflow-process' import { useChatContext } from '@/app/components/base/chat/chat/context' import Loading from '@/app/components/base/loading' @@ -31,8 +29,7 @@ import { Markdown } from '@/app/components/base/markdown' import NewAudioButton from '@/app/components/base/new-audio-button' import Toast from '@/app/components/base/toast' import { fetchTextGenerationMessage } from '@/service/debug' -import { AppSourceType, fetchMoreLikeThis, submitHumanInputForm, updateFeedback } from '@/service/share' -import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow' +import { 
 import { cn } from '@/utils/classnames'
 import ResultTab from './result-tab'
@@ -124,7 +121,7 @@ const GenerationItem: FC = ({
   const [isQuerying, { setTrue: startQuerying, setFalse: stopQuerying }] = useBoolean(false)

   const childProps = {
-    isInWebApp,
+    isInWebApp: true,
     content: completionRes,
     messageId: childMessageId,
     depth: depth + 1,
@@ -205,22 +202,16 @@ const GenerationItem: FC = ({
   }

   const [currentTab, setCurrentTab] = useState('DETAIL')
-  const showResultTabs = !!workflowProcessData?.resultText || !!workflowProcessData?.files?.length || (workflowProcessData?.humanInputFormDataList && workflowProcessData?.humanInputFormDataList.length > 0) || (workflowProcessData?.humanInputFilledFormDataList && workflowProcessData?.humanInputFilledFormDataList.length > 0)
+  const showResultTabs = !!workflowProcessData?.resultText || !!workflowProcessData?.files?.length
   const switchTab = async (tab: string) => {
     setCurrentTab(tab)
   }
   useEffect(() => {
-    if (workflowProcessData?.resultText || !!workflowProcessData?.files?.length || (workflowProcessData?.humanInputFormDataList && workflowProcessData?.humanInputFormDataList.length > 0) || (workflowProcessData?.humanInputFilledFormDataList && workflowProcessData?.humanInputFilledFormDataList.length > 0))
+    if (workflowProcessData?.resultText || !!workflowProcessData?.files?.length)
       switchTab('RESULT')
     else
       switchTab('DETAIL')
-  }, [workflowProcessData?.files?.length, workflowProcessData?.resultText, workflowProcessData?.humanInputFormDataList, workflowProcessData?.humanInputFilledFormDataList])
-  const handleSubmitHumanInputForm = useCallback(async (formToken: string, formData: { inputs: Record<string, any>, action: string }) => {
-    if (appSourceType === AppSourceType.installedApp)
-      await submitHumanInputFormService(formToken, formData)
-    else
-      await submitHumanInputForm(formToken, formData)
-  }, [appSourceType])
+  }, [workflowProcessData?.files?.length, workflowProcessData?.resultText])

   return (
     <>
@@ -284,24 +275,7 @@ const GenerationItem: FC = ({
         )}
         {!isError && (
-          <>
-            {currentTab === 'RESULT' && workflowProcessData.humanInputFormDataList && workflowProcessData.humanInputFormDataList.length > 0 && (
-              <div ...>
-                <HumanInputFormList ... />
-              </div>
-            )}
-            {currentTab === 'RESULT' && workflowProcessData.humanInputFilledFormDataList && workflowProcessData.humanInputFilledFormDataList.length > 0 && (
-              <div ...>
-                <HumanInputFilledFormList ... />
-              </div>
-            )}
-            <ResultTab ... />
-          </>
+          <ResultTab ... />
         )}
       )}
@@ -374,7 +348,7 @@ const GenerationItem: FC = ({
         )}
         {isInWebApp && isError && (
-          <RiResetLeftLine ... />
+          <RiReplay15Line ... />
         )}
         {isInWebApp && !isWorkflow && !isTryApp && (
diff --git a/web/app/components/app/workflow-log/list.tsx b/web/app/components/app/workflow-log/list.tsx
index 262efad781..b9597c8ea1 100644
--- a/web/app/components/app/workflow-log/list.tsx
+++ b/web/app/components/app/workflow-log/list.tsx
@@ -81,14 +81,6 @@ const WorkflowAppLogList: FC = ({ logs, appDetail, onRefresh }) => {
       )
     }
-    if (status === 'paused') {
-      return (
-        <div ...>
-          <Indicator ... />
-          Pending
-        </div>
-      )
-    }
     if (status === 'running') {
       return (
diff --git a/web/app/components/base/action-button/index.css b/web/app/components/base/action-button/index.css
index 4ede34aeb5..3c1a10b86f 100644
--- a/web/app/components/base/action-button/index.css
+++ b/web/app/components/base/action-button/index.css
@@ -26,10 +26,6 @@
     @apply p-0.5 w-6 h-6 rounded-lg
   }

-  .action-btn-s {
-    @apply w-5 h-5 rounded-[6px]
-  }
-
   .action-btn-xs {
     @apply p-0 w-4 h-4 rounded
   }
diff --git a/web/app/components/base/action-button/index.tsx b/web/app/components/base/action-button/index.tsx
index d182193b00..c91d472087 100644
--- a/web/app/components/base/action-button/index.tsx
+++ b/web/app/components/base/action-button/index.tsx
@@ -18,7 +18,6 @@ const actionButtonVariants = cva(
   variants: {
     size: {
       xs: 'action-btn-xs',
-      s: 'action-btn-s',
       m: 'action-btn-m',
       l: 'action-btn-l',
       xl: 'action-btn-xl',
diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx
index 304425b9a7..38a3f6c6b2 100644
--- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx
+++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx
@@ -2,7 +2,6 @@ import type { FileEntity } from '../../file-uploader/types'
 import type {
   ChatConfig,
   ChatItem,
-  ChatItemInTree,
   OnSend,
 } from '../types'
 import { useCallback, useEffect, useMemo, useState } from 'react'
@@ -17,9 +16,7 @@ import {
   fetchSuggestedQuestions,
   getUrl,
   stopChatMessageResponding,
-  submitHumanInputForm,
 } from '@/service/share'
-import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow'
 import { TransferMethod } from '@/types/app'
 import { cn } from '@/utils/classnames'
 import { formatBooleanInputs } from '@/utils/model-config'
@@ -76,9 +73,9 @@ const ChatWrapper = () => {
   }, [appParams, currentConversationItem?.introduction])
   const {
     chatList,
+    setTargetMessageId,
     handleSend,
     handleStop,
-    handleSwitchSibling,
     isResponding: respondingState,
     suggestedQuestions,
   } = useChat(
@@ -125,11 +122,8 @@ const ChatWrapper = () => {

     if (fileIsUploading)
       return true
-
-    if (chatList.some(item => item.isAnswer && item.humanInputFormDataList && item.humanInputFormDataList.length > 0))
-      return true
     return false
-  }, [allInputsHidden, inputsForms, chatList, inputsFormValue])
+  }, [inputsFormValue, inputsForms, allInputsHidden])

   useEffect(() => {
     if (currentChatInstanceRef.current)
@@ -140,40 +134,6 @@ const ChatWrapper = () => {
     setIsResponding(respondingState)
   }, [respondingState, setIsResponding])

-  // Resume paused workflows when chat history is loaded
-  useEffect(() => {
-    if (!appPrevChatTree || appPrevChatTree.length === 0)
-      return
-
-    // Find the last answer item with workflow_run_id that needs resumption (DFS - find deepest first)
-    let lastPausedNode: ChatItemInTree | undefined
-    const findLastPausedWorkflow = (nodes: ChatItemInTree[]) => {
-      nodes.forEach((node) => {
-        // DFS: recurse to children first
-        if (node.children && node.children.length > 0)
-          findLastPausedWorkflow(node.children)
-
-        // Track the last node with humanInputFormDataList
-        if (node.isAnswer && node.workflow_run_id && node.humanInputFormDataList && node.humanInputFormDataList.length > 0)
-          lastPausedNode = node
-      })
-    }
-
-    findLastPausedWorkflow(appPrevChatTree)
-
-    // Only resume the last paused workflow
-    if (lastPausedNode) {
-      handleSwitchSibling(
-        lastPausedNode.id,
-        {
-          onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, appSourceType, appId),
-          onConversationComplete: currentConversationId ? undefined : handleNewConversationCompleted,
-          isPublicAPI: appSourceType === AppSourceType.webApp,
-        },
-      )
-    }
-  }, [])
-
   const doSend: OnSend = useCallback((message, files, isRegenerate = false, parentAnswer: ChatItem | null = null) => {
     const data: any = {
       query: message,
@@ -189,10 +149,10 @@ const ChatWrapper = () => {
       {
         onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, appSourceType, appId),
         onConversationComplete: isHistoryConversation ? undefined : handleNewConversationCompleted,
-        isPublicAPI: appSourceType === AppSourceType.webApp,
+        isPublicAPI: !isInstalledApp,
       },
     )
-  }, [inputsForms, currentConversationId, currentConversationInputs, newConversationInputs, chatList, handleSend, appSourceType, appId, isHistoryConversation, handleNewConversationCompleted])
+  }, [chatList, handleNewConversationCompleted, handleSend, currentConversationId, currentConversationInputs, newConversationInputs, isInstalledApp, appId])

   const doRegenerate = useCallback((chatItem: ChatItem, editedQuestion?: { message: string, files?: FileEntity[] }) => {
     const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)!
@@ -200,27 +160,12 @@ const ChatWrapper = () => {
     doSend(editedQuestion ? editedQuestion.message : question.content, editedQuestion ? editedQuestion.files : question.message_files, true, isValidGeneratedAnswer(parentAnswer) ? parentAnswer : null)
   }, [chatList, doSend])

-  const doSwitchSibling = useCallback((siblingMessageId: string) => {
-    handleSwitchSibling(siblingMessageId, {
-      onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, appSourceType, appId),
-      onConversationComplete: currentConversationId ? undefined : handleNewConversationCompleted,
-      isPublicAPI: appSourceType === AppSourceType.webApp,
-    })
-  }, [handleSwitchSibling, currentConversationId, handleNewConversationCompleted, appSourceType, appId])
-
   const messageList = useMemo(() => {
     if (currentConversationId || chatList.length > 1)
       return chatList
     // Without messages we are in the welcome screen, so hide the opening statement from chatlist
     return chatList.filter(item => !item.isOpeningStatement)
-  }, [chatList, currentConversationId])
-
-  const handleSubmitHumanInputForm = useCallback(async (formToken: string, formData: any) => {
-    if (isInstalledApp)
-      await submitHumanInputFormService(formToken, formData)
-    else
-      await submitHumanInputForm(formToken, formData)
-  }, [isInstalledApp])
+  }, [chatList])

   const [collapsed, setCollapsed] = useState(!!currentConversationId)

@@ -329,7 +274,6 @@ const ChatWrapper = () => {
         inputsForm={inputsForms}
         onRegenerate={doRegenerate}
         onStopResponding={handleStop}
-        onHumanInputFormSubmit={handleSubmitHumanInputForm}
         chatNode={(
           <>
             {chatNode}
@@ -342,7 +286,7 @@ const ChatWrapper = () => {
         answerIcon={answerIcon}
         hideProcessDetail
         themeBuilder={themeBuilder}
-        switchSibling={doSwitchSibling}
+        switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)}
         inputDisabled={inputDisabled}
         sidebarCollapseState={sidebarCollapseState}
         questionIcon={
diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx
index da344a9789..ad1de38d07 100644
--- a/web/app/components/base/chat/chat-with-history/hooks.tsx
+++ b/web/app/components/base/chat/chat-with-history/hooks.tsx
@@ -1,4 +1,3 @@
-import type { ExtraContent } from '../chat/type'
 import type {
   Callback,
   ChatConfig,
@@ -10,7 +9,6 @@ import type {
   AppData,
   ConversationItem,
 } from '@/models/share'
-import type { HumanInputFilledFormData, HumanInputFormData } from '@/types/workflow'
 import { useLocalStorageState } from 'ahooks'
 import { noop } from 'es-toolkit/function'
 import { produce } from 'immer'
@@ -59,24 +57,6 @@ function getFormattedChatList(messages: any[]) {
       parentMessageId: item.parent_message_id || undefined,
     })
     const answerFiles = item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || []
-    const humanInputFormDataList: HumanInputFormData[] = []
-    const humanInputFilledFormDataList: HumanInputFilledFormData[] = []
-    let workflowRunId = ''
-    if (item.status === 'paused') {
-      item.extra_contents?.forEach((content: ExtraContent) => {
-        if (content.type === 'human_input' && !content.submitted) {
-          humanInputFormDataList.push(content.form_definition)
-          workflowRunId = content.workflow_run_id
-        }
-      })
-    }
-    else if (item.status === 'normal') {
-      item.extra_contents?.forEach((content: ExtraContent) => {
-        if (content.type === 'human_input' && content.submitted) {
-          humanInputFilledFormDataList.push(content.form_submission_data)
-        }
-      })
-    }
     newChatList.push({
       id: item.id,
       content: item.answer,
@@ -86,9 +66,6 @@ function getFormattedChatList(messages: any[]) {
       citation: item.retriever_resources,
       message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id, upload_file_id: item.upload_file_id }))),
       parentMessageId: `question-${item.id}`,
-      humanInputFormDataList,
-      humanInputFilledFormDataList,
-      workflow_run_id: workflowRunId,
     })
   })
   return newChatList
diff --git a/web/app/components/base/chat/chat/answer/human-input-content/content-item.tsx b/web/app/components/base/chat/chat/answer/human-input-content/content-item.tsx
deleted file mode 100644
index 3ed777d41e..0000000000
--- a/web/app/components/base/chat/chat/answer/human-input-content/content-item.tsx
+++ /dev/null
@@ -1,54 +0,0 @@
-import type { ContentItemProps } from './type'
-import * as React from 'react'
-import { useMemo } from 'react'
-import { Markdown } from '@/app/components/base/markdown'
-import Textarea from '@/app/components/base/textarea'
-
-const ContentItem = ({
-  content,
-  formInputFields,
-  inputs,
-  onInputChange,
-}: ContentItemProps) => {
-  const isInputField = (field: string) => {
-    const outputVarRegex = /\{\{#\$output\.[^#]+#\}\}/
-    return outputVarRegex.test(field)
-  }
-
-  const extractFieldName = (str: string): string => {
-    const outputVarRegex = /\{\{#\$output\.([^#]+)#\}\}/
-    const match = str.match(outputVarRegex)
-    return match ? match[1] : ''
-  }
-
-  const fieldName = useMemo(() => {
-    return extractFieldName(content)
-  }, [content])
-
-  const formInputField = useMemo(() => {
-    return formInputFields.find(field => field.output_variable_name === fieldName)
-  }, [formInputFields, fieldName])
-
-  if (!isInputField(content)) {
-    return (
-      <Markdown content={content} />
-    )
-  }
-
-  if (!formInputField)
-    return null
-
-  return (
-    <>
-      {formInputField.type === 'paragraph' && (
-        <Textarea ... />