This commit is contained in:
QuantumGhost 2025-11-12 08:50:09 +08:00
parent 4f48b8a57d
commit e47059514a
5 changed files with 107 additions and 29 deletions

View File

@ -9,15 +9,23 @@ from collections.abc import Generator
from flask import Response, jsonify
from flask_restx import Resource, reqparse
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import NotFoundError
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.workflow.nodes.human_input.entities import FormDefinition
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.enums import CreatorUserRole
from models.human_input import HumanInputForm as HumanInputFormModel
from models.model import App, EndUser
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.human_input_service import HumanInputService
logger = logging.getLogger(__name__)
@ -31,6 +39,7 @@ def _jsonify_pydantic_model(model: BaseModel) -> Response:
return Response(model.model_dump_json(), mimetype="application/json")
@console_ns.route("/form/human_input/<string:form_id>")
class ConsoleHumanInputFormApi(Resource):
"""Console API for getting human input form definition."""
@ -78,10 +87,6 @@ class ConsoleHumanInputFormApi(Resource):
return _jsonify_pydantic_model(form.get_definition())
class ConsoleHumanInputFormSubmissionApi(Resource):
"""Console API for submitting human input forms."""
@account_initialization_required
@login_required
def post(self, form_id: str):
@ -114,6 +119,7 @@ class ConsoleHumanInputFormSubmissionApi(Resource):
return jsonify({})
@console_ns.route("/workflow/<string:workflow_run_id>/events")
class ConsoleWorkflowEventsApi(Resource):
"""Console API for getting workflow execution events after resume."""
@ -128,27 +134,57 @@ class ConsoleWorkflowEventsApi(Resource):
Returns Server-Sent Events stream.
"""
events =
user, tenant_id = current_account_with_tenant()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=tenant_id,
run_id=workflow_run_id,
)
if workflow_run is None:
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
def generate_events() -> Generator[str, None, None]:
"""Generate SSE events for workflow execution."""
try:
# TODO: Implement actual event streaming
# This would connect to the workflow execution engine
# and stream real-time events
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT:
raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}")
# For demo purposes, send a basic event
yield f"data: {{'event': 'workflow_resumed', 'task_id': '{task_id}'}}\n\n"
if workflow_run.created_by != user.id:
raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}")
# In real implementation, this would:
# 1. Connect to workflow execution engine
# 2. Stream real-time execution events
# 3. Handle client disconnection
# 4. Clean up resources on completion
with Session(expire_on_commit=False, bind=db.engine) as session:
app = _retrieve_app_for_workflow_run(session, workflow_run)
except Exception as e:
logger.exception("Error streaming events for task %s", task_id)
yield f"data: {{'error': 'Stream error: {str(e)}'}}\n\n"
if workflow_run.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
workflow_run=workflow_run,
creator_user=user,
)
# We'll
def generate_events() -> Generator[str, None, None]:
"""Generate SSE events for workflow execution."""
try:
# TODO: Implement actual event streaming
# This would connect to the workflow execution engine
# and stream real-time events
# For demo purposes, send a basic event
yield f"data: {{'event': 'workflow_resumed', 'task_id': '{task_id}'}}\n\n"
# In real implementation, this would:
# 1. Connect to workflow execution engine
# 2. Stream real-time execution events
# 3. Handle client disconnection
# 4. Clean up resources on completion
except Exception as e:
logger.exception("Error streaming events for task %s", task_id)
yield f"data: {{'error': 'Stream error: {str(e)}'}}\n\n"
else:
# TODO: SSE from Redis PubSub
queue = ...
def generate_events():
yield from []
return Response(
generate_events(),
@ -160,6 +196,7 @@ class ConsoleWorkflowEventsApi(Resource):
)
@console_ns.route("/workflow/<string:workflow_run_id>/pause-details")
class ConsoleWorkflowPauseDetailsApi(Resource):
"""Console API for getting workflow pause details."""
@ -222,8 +259,14 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
return response, 200
# Register the APIs
api.add_resource(ConsoleHumanInputFormApi, "/form/human_input/<string:form_id>")
api.add_resource(ConsoleHumanInputFormSubmissionApi, "/form/human_input/<string:form_id>", methods=["POST"])
api.add_resource(ConsoleWorkflowEventsApi, "/workflow/<string:workflow_run_id>/events")
api.add_resource(ConsoleWorkflowPauseDetailsApi, "/workflow/<string:workflow_run_id>/pause-details")
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
query = select(App).where(
App.id == workflow_run.app_id,
App.tenant_id == workflow_run.tenant_id,
)
app = session.scalars(query).first()
if app is None:
raise AssertionError(
f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, "
f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}"
)

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

View File

@ -5,6 +5,7 @@ from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
from typing_extensions import deprecated
import sqlalchemy as sa
from sqlalchemy import (
@ -405,6 +406,11 @@ class Workflow(Base): # bug
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
@deprecated(
"This property is not accurate for determining if a workflow is published as a tool."
"It only checks if there's a WorkflowToolProvider for the app, "
"not if this specific workflow version is the one being used by the tool."
)
@property
def tool_published(self) -> bool:
"""
@ -616,7 +622,7 @@ class WorkflowRun(Base):
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) # account, end_user
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
@ -631,11 +637,13 @@ class WorkflowRun(Base):
back_populates="workflow_run",
)
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
@property
def created_by_account(self):
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
@property
def created_by_end_user(self):
from .model import EndUser
@ -655,6 +663,7 @@ class WorkflowRun(Base):
def outputs_dict(self) -> Mapping[str, Any]:
return json.loads(self.outputs) if self.outputs else {}
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
@property
def message(self):
from .model import Message
@ -663,6 +672,7 @@ class WorkflowRun(Base):
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
)
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
@property
def workflow(self):
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()

View File

@ -479,3 +479,19 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
[{"date": "2024-01-01", "interactions": 2.5}, ...]
"""
...
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
"""
Get a specific workflow run by its id and the associated tenant id.
This function does not apply application isolation. It should only be used when
the application identifier is not available.
Args:
tenant_id: Tenant identifier for multi-tenant isolation
run_id: Workflow run identifier
Returns:
WorkflowRun object if found, None otherwise
"""
...

View File

@ -839,6 +839,15 @@ GROUP BY
return cast(list[AverageInteractionStats], response_data)
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
"""Get a specific workflow run by its id and the associated tenant id."""
with self._session_maker() as session:
stmt = select(WorkflowRun).where(
WorkflowRun.tenant_id == tenant_id,
WorkflowRun.id == run_id,
)
return session.scalar(stmt)
class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
"""