diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index eb533df424..8c11739d10 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -2,7 +2,7 @@ import json import logging import time import uuid -from collections.abc import Generator, Mapping +from collections.abc import Callable, Generator, Mapping from typing import Any, Union, cast from sqlalchemy import select @@ -309,15 +309,25 @@ class MessageBasedAppGenerator(BaseAppGenerator): @classmethod def retrieve_events( - cls, app_mode: AppMode, workflow_run_id: uuid.UUID, idle_timeout=300 + cls, + app_mode: AppMode, + workflow_run_id: uuid.UUID, + idle_timeout=300, + on_subscribe: Callable[[], None] | None = None, ) -> Generator[Mapping | str, None, None]: topic = cls.get_response_topic(app_mode, workflow_run_id) - return _topic_msg_generator(topic, idle_timeout) + return _topic_msg_generator(topic, idle_timeout, on_subscribe) -def _topic_msg_generator(topic: Topic, idle_timeout: float) -> Generator[Mapping[str, Any], None, None]: +def _topic_msg_generator( + topic: Topic, + idle_timeout: float, + on_subscribe: Callable[[], None] | None = None, +) -> Generator[Mapping[str, Any], None, None]: last_msg_time = time.time() with topic.subscribe() as sub: + if on_subscribe is not None: + on_subscribe() while True: try: msg = sub.receive() diff --git a/api/core/app/apps/message_generator.py b/api/core/app/apps/message_generator.py index a943d65523..aeec3289a8 100644 --- a/api/core/app/apps/message_generator.py +++ b/api/core/app/apps/message_generator.py @@ -1,6 +1,6 @@ import json import time -from collections.abc import Generator, Mapping +from collections.abc import Callable, Generator, Mapping from typing import Any from core.app.entities.task_entities import ( @@ -27,15 +27,25 @@ class MessageGenerator: @classmethod def retrieve_events( - cls, app_mode: AppMode, workflow_run_id: str, idle_timeout=300 + cls, + app_mode: AppMode, + workflow_run_id: str, + idle_timeout=300, + on_subscribe: Callable[[], None] | None = None, ) -> Generator[Mapping | str, None, None]: topic = cls.get_response_topic(app_mode, workflow_run_id) - return _topic_msg_generator(topic, idle_timeout) + return _topic_msg_generator(topic, idle_timeout, on_subscribe) -def _topic_msg_generator(topic: Topic, idle_timeout: float) -> Generator[Mapping[str, Any], None, None]: +def _topic_msg_generator( + topic: Topic, + idle_timeout: float, + on_subscribe: Callable[[], None] | None = None, +) -> Generator[Mapping[str, Any], None, None]: last_msg_time = time.time() with topic.subscribe() as sub: + if on_subscribe is not None: + on_subscribe() while True: try: msg = sub.receive() diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index e93ce5e36e..b9eade0352 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence from typing import Any, cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, QueueAgentLogEvent, diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 1da0e32bc0..937da1848c 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,5 +1,7 @@ +import logging +import threading import uuid -from collections.abc import Generator, Mapping +from collections.abc import Callable, Generator, Mapping from typing import Any, Union from configs import dify_config @@ -20,8 +22,44 @@ from services.errors.app import InvokeRateLimitError, QuotaExceededError, Workfl from services.workflow_service import WorkflowService from tasks.app_generate.workflow_execute_task import AppExecutionParams, chatflow_execute_task +logger = logging.getLogger(__name__) + +SSE_TASK_START_FALLBACK_MS = 200 + class AppGenerateService: + @staticmethod + def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]: + started = False + lock = threading.Lock() + + def _try_start() -> bool: + nonlocal started + with lock: + if started: + return True + try: + start_task() + except Exception: + logger.exception("Failed to enqueue streaming task") + return False + started = True + return True + + # XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber. + # The Celery task may publish the first event before the API side actually subscribes, + # causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe, + # but also use a short fallback timer so the task still runs if the client never consumes. + timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start) + timer.daemon = True + timer.start() + + def _on_subscribe() -> None: + if _try_start(): + timer.cancel() + + return _on_subscribe + @classmethod @trace_span(AppGenerateHandler) def generate( @@ -95,11 +133,18 @@ class AppGenerateService: streaming=streaming, call_depth=0, ) - chatflow_execute_task.delay(payload.model_dump_json()) + payload_json = payload.model_dump_json() + on_subscribe = cls._build_streaming_task_on_subscribe( + lambda: chatflow_execute_task.delay(payload_json) + ) generator = AdvancedChatAppGenerator() return rate_limit.generate( generator.convert_to_event_stream( - generator.retrieve_events(AppMode.ADVANCED_CHAT, payload.workflow_run_id), + generator.retrieve_events( + AppMode.ADVANCED_CHAT, + payload.workflow_run_id, + on_subscribe=on_subscribe, + ), ), request_id=request_id, ) @@ -119,10 +164,17 @@ class AppGenerateService: root_node_id=root_node_id, workflow_run_id=uuid.uuid4(), ) - chatflow_execute_task.delay(payload.model_dump_json()) + payload_json = payload.model_dump_json() + on_subscribe = cls._build_streaming_task_on_subscribe( + lambda: chatflow_execute_task.delay(payload_json) + ) return rate_limit.generate( WorkflowAppGenerator.convert_to_event_stream( - MessageBasedAppGenerator.retrieve_events(AppMode.WORKFLOW, payload.workflow_run_id), + MessageBasedAppGenerator.retrieve_events( + AppMode.WORKFLOW, + payload.workflow_run_id, + on_subscribe=on_subscribe, + ), ), request_id, )