fix(api): fix race condition between workflow execution and SSE subscription

This commit is contained in:
QuantumGhost 2026-01-07 09:45:12 +08:00
parent 001d2c5062
commit 3c79bea28f
4 changed files with 85 additions and 14 deletions

View File

@ -2,7 +2,7 @@ import json
import logging
import time
import uuid
from collections.abc import Generator, Mapping
from collections.abc import Callable, Generator, Mapping
from typing import Any, Union, cast
from sqlalchemy import select
@ -309,15 +309,25 @@ class MessageBasedAppGenerator(BaseAppGenerator):
@classmethod
def retrieve_events(
cls, app_mode: AppMode, workflow_run_id: uuid.UUID, idle_timeout=300
cls,
app_mode: AppMode,
workflow_run_id: uuid.UUID,
idle_timeout=300,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return _topic_msg_generator(topic, idle_timeout)
return _topic_msg_generator(topic, idle_timeout, on_subscribe)
def _topic_msg_generator(topic: Topic, idle_timeout: float) -> Generator[Mapping[str, Any], None, None]:
def _topic_msg_generator(
topic: Topic,
idle_timeout: float,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping[str, Any], None, None]:
last_msg_time = time.time()
with topic.subscribe() as sub:
if on_subscribe is not None:
on_subscribe()
while True:
try:
msg = sub.receive()

View File

@ -1,6 +1,6 @@
import json
import time
from collections.abc import Generator, Mapping
from collections.abc import Callable, Generator, Mapping
from typing import Any
from core.app.entities.task_entities import (
@ -27,15 +27,25 @@ class MessageGenerator:
@classmethod
def retrieve_events(
cls, app_mode: AppMode, workflow_run_id: str, idle_timeout=300
cls,
app_mode: AppMode,
workflow_run_id: str,
idle_timeout=300,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return _topic_msg_generator(topic, idle_timeout)
return _topic_msg_generator(topic, idle_timeout, on_subscribe)
def _topic_msg_generator(topic: Topic, idle_timeout: float) -> Generator[Mapping[str, Any], None, None]:
def _topic_msg_generator(
topic: Topic,
idle_timeout: float,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping[str, Any], None, None]:
last_msg_time = time.time()
with topic.subscribe() as sub:
if on_subscribe is not None:
on_subscribe()
while True:
try:
msg = sub.receive()

View File

@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,

View File

@ -1,5 +1,7 @@
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from collections.abc import Callable, Generator, Mapping
from typing import Any, Union
from configs import dify_config
@ -20,8 +22,44 @@ from services.errors.app import InvokeRateLimitError, QuotaExceededError, Workfl
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, chatflow_execute_task
logger = logging.getLogger(__name__)
SSE_TASK_START_FALLBACK_MS = 200
class AppGenerateService:
@staticmethod
def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]:
started = False
lock = threading.Lock()
def _try_start() -> bool:
nonlocal started
with lock:
if started:
return True
try:
start_task()
except Exception:
logger.exception("Failed to enqueue streaming task")
return False
started = True
return True
# XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber.
# The Celery task may publish the first event before the API side actually subscribes,
# causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe,
# but also use a short fallback timer so the task still runs if the client never consumes.
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)
timer.daemon = True
timer.start()
def _on_subscribe() -> None:
if _try_start():
timer.cancel()
return _on_subscribe
@classmethod
@trace_span(AppGenerateHandler)
def generate(
@ -95,11 +133,18 @@ class AppGenerateService:
streaming=streaming,
call_depth=0,
)
chatflow_execute_task.delay(payload.model_dump_json())
payload_json = payload.model_dump_json()
on_subscribe = cls._build_streaming_task_on_subscribe(
lambda: chatflow_execute_task.delay(payload_json)
)
generator = AdvancedChatAppGenerator()
return rate_limit.generate(
generator.convert_to_event_stream(
generator.retrieve_events(AppMode.ADVANCED_CHAT, payload.workflow_run_id),
generator.retrieve_events(
AppMode.ADVANCED_CHAT,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id=request_id,
)
@ -119,10 +164,17 @@ class AppGenerateService:
root_node_id=root_node_id,
workflow_run_id=uuid.uuid4(),
)
chatflow_execute_task.delay(payload.model_dump_json())
payload_json = payload.model_dump_json()
on_subscribe = cls._build_streaming_task_on_subscribe(
lambda: chatflow_execute_task.delay(payload_json)
)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
MessageBasedAppGenerator.retrieve_events(AppMode.WORKFLOW, payload.workflow_run_id),
MessageBasedAppGenerator.retrieve_events(
AppMode.WORKFLOW,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id,
)