mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
fix(api): fix race condition between workflow execution and SSE subscription
This commit is contained in:
parent
001d2c5062
commit
3c79bea28f
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
@ -309,15 +309,25 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
|
||||
@classmethod
|
||||
def retrieve_events(
|
||||
cls, app_mode: AppMode, workflow_run_id: uuid.UUID, idle_timeout=300
|
||||
cls,
|
||||
app_mode: AppMode,
|
||||
workflow_run_id: uuid.UUID,
|
||||
idle_timeout=300,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
) -> Generator[Mapping | str, None, None]:
|
||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
||||
return _topic_msg_generator(topic, idle_timeout)
|
||||
return _topic_msg_generator(topic, idle_timeout, on_subscribe)
|
||||
|
||||
|
||||
def _topic_msg_generator(topic: Topic, idle_timeout: float) -> Generator[Mapping[str, Any], None, None]:
|
||||
def _topic_msg_generator(
|
||||
topic: Topic,
|
||||
idle_timeout: float,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
) -> Generator[Mapping[str, Any], None, None]:
|
||||
last_msg_time = time.time()
|
||||
with topic.subscribe() as sub:
|
||||
if on_subscribe is not None:
|
||||
on_subscribe()
|
||||
while True:
|
||||
try:
|
||||
msg = sub.receive()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.task_entities import (
|
||||
@ -27,15 +27,25 @@ class MessageGenerator:
|
||||
|
||||
@classmethod
|
||||
def retrieve_events(
|
||||
cls, app_mode: AppMode, workflow_run_id: str, idle_timeout=300
|
||||
cls,
|
||||
app_mode: AppMode,
|
||||
workflow_run_id: str,
|
||||
idle_timeout=300,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
) -> Generator[Mapping | str, None, None]:
|
||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
||||
return _topic_msg_generator(topic, idle_timeout)
|
||||
return _topic_msg_generator(topic, idle_timeout, on_subscribe)
|
||||
|
||||
|
||||
def _topic_msg_generator(topic: Topic, idle_timeout: float) -> Generator[Mapping[str, Any], None, None]:
|
||||
def _topic_msg_generator(
|
||||
topic: Topic,
|
||||
idle_timeout: float,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
) -> Generator[Mapping[str, Any], None, None]:
|
||||
last_msg_time = time.time()
|
||||
with topic.subscribe() as sub:
|
||||
if on_subscribe is not None:
|
||||
on_subscribe()
|
||||
while True:
|
||||
try:
|
||||
msg = sub.receive()
|
||||
|
||||
@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from configs import dify_config
|
||||
@ -20,8 +22,44 @@ from services.errors.app import InvokeRateLimitError, QuotaExceededError, Workfl
|
||||
from services.workflow_service import WorkflowService
|
||||
from tasks.app_generate.workflow_execute_task import AppExecutionParams, chatflow_execute_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SSE_TASK_START_FALLBACK_MS = 200
|
||||
|
||||
|
||||
class AppGenerateService:
|
||||
@staticmethod
|
||||
def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]:
|
||||
started = False
|
||||
lock = threading.Lock()
|
||||
|
||||
def _try_start() -> bool:
|
||||
nonlocal started
|
||||
with lock:
|
||||
if started:
|
||||
return True
|
||||
try:
|
||||
start_task()
|
||||
except Exception:
|
||||
logger.exception("Failed to enqueue streaming task")
|
||||
return False
|
||||
started = True
|
||||
return True
|
||||
|
||||
# XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber.
|
||||
# The Celery task may publish the first event before the API side actually subscribes,
|
||||
# causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe,
|
||||
# but also use a short fallback timer so the task still runs if the client never consumes.
|
||||
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)
|
||||
timer.daemon = True
|
||||
timer.start()
|
||||
|
||||
def _on_subscribe() -> None:
|
||||
if _try_start():
|
||||
timer.cancel()
|
||||
|
||||
return _on_subscribe
|
||||
|
||||
@classmethod
|
||||
@trace_span(AppGenerateHandler)
|
||||
def generate(
|
||||
@ -95,11 +133,18 @@ class AppGenerateService:
|
||||
streaming=streaming,
|
||||
call_depth=0,
|
||||
)
|
||||
chatflow_execute_task.delay(payload.model_dump_json())
|
||||
payload_json = payload.model_dump_json()
|
||||
on_subscribe = cls._build_streaming_task_on_subscribe(
|
||||
lambda: chatflow_execute_task.delay(payload_json)
|
||||
)
|
||||
generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
generator.convert_to_event_stream(
|
||||
generator.retrieve_events(AppMode.ADVANCED_CHAT, payload.workflow_run_id),
|
||||
generator.retrieve_events(
|
||||
AppMode.ADVANCED_CHAT,
|
||||
payload.workflow_run_id,
|
||||
on_subscribe=on_subscribe,
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
@ -119,10 +164,17 @@ class AppGenerateService:
|
||||
root_node_id=root_node_id,
|
||||
workflow_run_id=uuid.uuid4(),
|
||||
)
|
||||
chatflow_execute_task.delay(payload.model_dump_json())
|
||||
payload_json = payload.model_dump_json()
|
||||
on_subscribe = cls._build_streaming_task_on_subscribe(
|
||||
lambda: chatflow_execute_task.delay(payload_json)
|
||||
)
|
||||
return rate_limit.generate(
|
||||
WorkflowAppGenerator.convert_to_event_stream(
|
||||
MessageBasedAppGenerator.retrieve_events(AppMode.WORKFLOW, payload.workflow_run_id),
|
||||
MessageBasedAppGenerator.retrieve_events(
|
||||
AppMode.WORKFLOW,
|
||||
payload.workflow_run_id,
|
||||
on_subscribe=on_subscribe,
|
||||
),
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user