diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 0879fb9896..4ddfcb921f 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -13,18 +13,21 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import Forbidden -from controllers.console import api, console_ns +from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from controllers.web.error import NotFoundError +from controllers.web.error import InvalidArgumentError, NotFoundError +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.message_generator import MessageGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.workflow.nodes.human_input.entities import FormDefinition from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required -from models.account import Account +from models import App from models.enums import CreatorUserRole from models.human_input import HumanInputForm as HumanInputFormModel -from models.model import App, EndUser -from models.workflow import WorkflowRun +from models.model import AppMode +from models.workflow import Workflow, WorkflowRun from repositories.factory import DifyAPIRepositoryFactory from services.human_input_service import HumanInputService @@ -64,9 +67,6 @@ class ConsoleHumanInputFormApi(Resource): if form_model is None or form_model.tenant_id != current_tenant_id: raise NotFoundError(f"form not found, id={form_id}") - from models import App - from models.workflow import Workflow, WorkflowRun - workflow_run = db.session.get(WorkflowRun, form_model.workflow_run_id) if workflow_run is None or workflow_run.tenant_id != current_tenant_id: raise NotFoundError("Workflow run not found") @@ -159,7 +159,7 @@ class ConsoleWorkflowEventsApi(Resource): creator_user=user, ) - # We'll + # TODO: should we just return here? or yield a WorkflowFinishStreamResponse? def generate_events() -> Generator[str, None, None]: """Generate SSE events for workflow execution.""" try: @@ -181,10 +181,18 @@ class ConsoleWorkflowEventsApi(Resource): yield f"data: {{'error': 'Stream error: {str(e)}'}}\n\n" else: # TODO: SSE from Redis PubSub - queue = ... + msg_generator = MessageGenerator() + if app.mode == AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + elif app.mode == AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + else: + raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") def generate_events(): - yield from [] + return generator.convert_to_event_stream( + msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id), + ) return Response( generate_events(), @@ -210,7 +218,6 @@ class ConsoleWorkflowPauseDetailsApi(Resource): Returns information about why and where the workflow is paused. """ - from models.workflow import WorkflowRun # Query WorkflowRun to determine if workflow is suspended workflow_run = db.session.get(WorkflowRun, workflow_run_id) @@ -270,3 +277,5 @@ def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun): f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, " f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}" ) + + return app diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py index f2023feb33..3c7ea428e4 100644 --- a/api/controllers/web/human_input_form.py +++ b/api/controllers/web/human_input_form.py @@ -14,13 +14,6 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from extensions.ext_database import db from models.human_input import HumanInputSubmissionType -from services.human_input_form_service import ( - HumanInputFormAlreadySubmittedError, - HumanInputFormExpiredError, - HumanInputFormNotFoundError, - HumanInputFormService, - InvalidFormDataError, -) logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index e1a9e38166..b817d0df15 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime -from typing import Any, NamedTuple, NewType, Union, final +from typing import Any, NewType, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index f27819c34c..e3e5c0ea64 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -2,12 +2,9 @@ import json import logging import time import uuid -from collections.abc import Generator -from typing import Any, Mapping, Union, cast +from collections.abc import Generator, Mapping +from typing import Any, Union, cast -from libs.broadcast_channel.channel import Subscription, Topic -from libs.broadcast_channel.exc import SubscriptionClosedError -from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from sqlalchemy import select from sqlalchemy.orm import Session @@ -34,6 +31,9 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.broadcast_channel.channel import Topic +from libs.broadcast_channel.exc import SubscriptionClosedError +from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from libs.datetime_utils import naive_utc_now from models import Account from models.enums import CreatorUserRole diff --git a/api/core/app/apps/message_generator.py b/api/core/app/apps/message_generator.py new file mode 100644 index 0000000000..a943d65523 --- /dev/null +++ b/api/core/app/apps/message_generator.py @@ -0,0 +1,59 @@ +import json +import time +from collections.abc import Generator, Mapping +from typing import Any + +from core.app.entities.task_entities import ( + StreamEvent, +) +from extensions.ext_redis import redis_client +from libs.broadcast_channel.channel import Topic +from libs.broadcast_channel.exc import SubscriptionClosedError +from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel +from models.model import AppMode + + +class MessageGenerator: + @staticmethod + def _make_channel_key(app_mode: AppMode, workflow_run_id: str): + return f"channel:{app_mode}:{str(workflow_run_id)}" + + @classmethod + def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic: + key = cls._make_channel_key(app_mode, workflow_run_id) + channel = RedisBroadcastChannel(redis_client) + topic = channel.topic(key) + return topic + + @classmethod + def retrieve_events( + cls, app_mode: AppMode, workflow_run_id: str, idle_timeout=300 + ) -> Generator[Mapping | str, None, None]: + topic = cls.get_response_topic(app_mode, workflow_run_id) + return _topic_msg_generator(topic, idle_timeout) + + +def _topic_msg_generator(topic: Topic, idle_timeout: float) -> Generator[Mapping[str, Any], None, None]: + last_msg_time = time.time() + with topic.subscribe() as sub: + while True: + try: + msg = sub.receive() + except SubscriptionClosedError: + return + if msg is None: + current_time = time.time() + if current_time - last_msg_time > idle_timeout: + return + # skip the `None` message + continue + + last_msg_time = time.time() + event = json.loads(msg) + yield event + if not isinstance(event, dict): + continue + + event_type = event.get("event") + if event_type in (StreamEvent.WORKFLOW_FINISHED, StreamEvent.WORKFLOW_PAUSED): + return diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 1ba7fdf921..0cb573cb86 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Optional from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator diff --git a/api/core/repositories/human_input_reposotiry.py b/api/core/repositories/human_input_reposotiry.py index 1ab09b91cd..3513a5da99 100644 --- a/api/core/repositories/human_input_reposotiry.py +++ b/api/core/repositories/human_input_reposotiry.py @@ -1,12 +1,10 @@ -import abc import dataclasses import json -import uuid -from collections.abc import Sequence -from typing import Any, Mapping +from collections.abc import Mapping, Sequence +from typing import Any from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import sessionmaker from core.workflow.nodes.human_input.entities import ( DeliveryChannelConfig, diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 0a35dad5eb..837af562fb 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -8,10 +8,10 @@ from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union import redis from redis import RedisError from redis.cache import CacheConfig +from redis.client import PubSub from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection from redis.sentinel import Sentinel -from redis.client import PubSub from configs import dify_config from dify_app import DifyApp diff --git a/api/models/human_input.py b/api/models/human_input.py index 7bfae33050..23e71778b3 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import StrEnum -from typing import Annotated, Any, ClassVar, Literal, Self, final +from typing import Annotated, Literal, Self, final import sqlalchemy as sa from pydantic import BaseModel, Field @@ -8,7 +8,6 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from core.workflow.nodes.human_input.entities import ( DeliveryMethodType, - EmailRecipientType, HumanInputFormStatus, ) from libs.helper import generate_string diff --git a/api/models/workflow.py b/api/models/workflow.py index c73d39009a..ef8d557f72 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -5,7 +5,6 @@ from datetime import datetime from enum import StrEnum from typing import TYPE_CHECKING, Any, Optional, Union, cast from uuid import uuid4 -from typing_extensions import deprecated import sqlalchemy as sa from sqlalchemy import ( @@ -21,6 +20,7 @@ from sqlalchemy import ( select, ) from sqlalchemy.orm import Mapped, declared_attr, mapped_column +from typing_extensions import deprecated from core.file.constants import maybe_file_object from core.file.models import File diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 0a3b5edfbd..57acc9725e 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -1,9 +1,9 @@ import contextlib import logging +import uuid from collections.abc import Mapping from enum import StrEnum -from typing import Annotated, Any, Generator, Literal, Self, TypeAlias, Union, overload -import uuid +from typing import Annotated, Any, TypeAlias, Union from celery import shared_task from flask import current_app, json @@ -12,13 +12,10 @@ from sqlalchemy import Engine from sqlalchemy.orm import sessionmaker from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator -from core.app.apps.base_app_generator import BaseAppGenerator from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from extensions.ext_database import db -from extensions.ext_redis import redis_client -from libs.broadcast_channel.redis import BroadcastChannel as RedisBroadcastChannel from libs.flask_utils import set_login_user from models.account import Account from models.model import App, AppMode, EndUser diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py index f81efa1bd0..e3f0d8a96e 100644 --- a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py +++ b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/__init__.py @@ -5,6 +5,14 @@ This module provides utility classes and functions for testing Redis broadcast channel functionality. """ +from .test_data import ( + LARGE_MESSAGES, + SMALL_MESSAGES, + SPECIAL_MESSAGES, + BufferTestConfig, + ConcurrencyTestConfig, + ErrorTestConfig, +) from .test_helpers import ( ConcurrentPublisher, SubscriptionMonitor, @@ -12,25 +20,17 @@ from .test_helpers import ( measure_throughput, wait_for_condition, ) -from .test_data import ( - BufferTestConfig, - ConcurrencyTestConfig, - ErrorTestConfig, - LARGE_MESSAGES, - SMALL_MESSAGES, - SPECIAL_MESSAGES, -) __all__ = [ + "LARGE_MESSAGES", + "SMALL_MESSAGES", + "SPECIAL_MESSAGES", + "BufferTestConfig", + "ConcurrencyTestConfig", "ConcurrentPublisher", + "ErrorTestConfig", "SubscriptionMonitor", "assert_message_order", "measure_throughput", "wait_for_condition", - "BufferTestConfig", - "ConcurrencyTestConfig", - "ErrorTestConfig", - "LARGE_MESSAGES", - "SMALL_MESSAGES", - "SPECIAL_MESSAGES", ] diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py index 45aee96616..2cccb08304 100644 --- a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py +++ b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_data.py @@ -75,7 +75,7 @@ VERY_LARGE_MESSAGES = [ SPECIAL_MESSAGES = [ b"", # Empty message b"\x00\x01\x02", # Binary data with null bytes - "unicode_test_你好".encode("utf-8"), # Unicode + "unicode_test_你好".encode(), # Unicode b"special_chars_!@#$%^&*()_+-=[]{}|;':\",./<>?", # Special characters b"newlines\n\r\t", # Control characters ] @@ -241,8 +241,8 @@ EDGE_CASE_MESSAGES = [ b"\x00", # Single null byte b"\xff", # Single max byte value b"a", # Single ASCII character - "ä".encode("utf-8"), # Single unicode character (2 bytes) - "𐍈".encode("utf-8"), # Unicode character outside BMP (4 bytes) + "ä".encode(), # Single unicode character (2 bytes) + "𐍈".encode(), # Unicode character outside BMP (4 bytes) b"\x00" * 1000, # 1000 null bytes b"\xff" * 1000, # 1000 max byte values ] diff --git a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py index 3901df3062..80895553cb 100644 --- a/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py +++ b/api/tests/integration_tests/libs/broadcast_channel/redis/utils/test_helpers.py @@ -8,7 +8,7 @@ operations, monitoring subscriptions, and measuring performance. import logging import threading import time -from collections.abc import Callable, Generator +from collections.abc import Callable from typing import Any _logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ class ConcurrentPublisher: if self.delay > 0: time.sleep(self.delay) except Exception as e: - _logger.error(f"Publisher {thread_id} error: {e}") + _logger.error("Publisher %s error: %s", thread_id, e) with self._lock: self.published_messages.append(messages) @@ -280,7 +280,7 @@ def assert_message_order(received: list[bytes], expected: list[bytes]) -> bool: for i, (recv_msg, exp_msg) in enumerate(zip(received, expected)): if recv_msg != exp_msg: - _logger.error(f"Message order mismatch at index {i}: expected {exp_msg}, got {recv_msg}") + _logger.error("Message order mismatch at index %s: expected %s, got %s", i, exp_msg, recv_msg) return False return True @@ -309,7 +309,7 @@ def measure_throughput( operation() count += 1 except Exception as e: - _logger.error(f"Operation failed: {e}") + _logger.error("Operation failed: %s", e) break elapsed = time.time() - start_time diff --git a/api/tests/unit_tests/core/human_input_form_test.py b/api/tests/unit_tests/core/human_input_form_test.py index f4766c789f..25e1707871 100644 --- a/api/tests/unit_tests/core/human_input_form_test.py +++ b/api/tests/unit_tests/core/human_input_form_test.py @@ -3,13 +3,12 @@ Tests for HumanInputForm domain model and repository. """ import json -from datetime import datetime, timedelta +from datetime import datetime from unittest.mock import MagicMock, patch import pytest - -from core.workflow.entities.human_input_form import HumanInputForm, HumanInputFormStatus, HumanInputSubmissionType from core.repositories.sqlalchemy_human_input_form_repository import SQLAlchemyHumanInputFormRepository +from core.workflow.entities.human_input_form import HumanInputForm, HumanInputFormStatus class TestHumanInputForm: @@ -201,7 +200,11 @@ class TestSQLAlchemyHumanInputFormRepository: """Test converting DB model to domain model.""" from models.human_input import ( HumanInputForm as DBForm, + ) + from models.human_input import ( HumanInputFormStatus as DBStatus, + ) + from models.human_input import ( HumanInputSubmissionType as DBSubmissionType, ) @@ -233,9 +236,7 @@ class TestSQLAlchemyHumanInputFormRepository: def test_to_db_model(self, repository): """Test converting domain model to DB model.""" from models.human_input import ( - HumanInputForm as DBForm, HumanInputFormStatus as DBStatus, - HumanInputSubmissionType as DBSubmissionType, ) domain_form = HumanInputForm.create(