mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
WIP: P3
This commit is contained in:
parent
4f48b8a57d
commit
e47059514a
@ -9,15 +9,23 @@ from collections.abc import Generator
|
||||
from flask import Response, jsonify
|
||||
from flask_restx import Resource, reqparse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.human_input import HumanInputForm as HumanInputFormModel
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.human_input_service import HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -31,6 +39,7 @@ def _jsonify_pydantic_model(model: BaseModel) -> Response:
|
||||
return Response(model.model_dump_json(), mimetype="application/json")
|
||||
|
||||
|
||||
@console_ns.route("/form/human_input/<string:form_id>")
|
||||
class ConsoleHumanInputFormApi(Resource):
|
||||
"""Console API for getting human input form definition."""
|
||||
|
||||
@ -78,10 +87,6 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
|
||||
return _jsonify_pydantic_model(form.get_definition())
|
||||
|
||||
|
||||
class ConsoleHumanInputFormSubmissionApi(Resource):
|
||||
"""Console API for submitting human input forms."""
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def post(self, form_id: str):
|
||||
@ -114,6 +119,7 @@ class ConsoleHumanInputFormSubmissionApi(Resource):
|
||||
return jsonify({})
|
||||
|
||||
|
||||
@console_ns.route("/workflow/<string:workflow_run_id>/events")
|
||||
class ConsoleWorkflowEventsApi(Resource):
|
||||
"""Console API for getting workflow execution events after resume."""
|
||||
|
||||
@ -128,27 +134,57 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
|
||||
events =
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=tenant_id,
|
||||
run_id=workflow_run_id,
|
||||
)
|
||||
if workflow_run is None:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
def generate_events() -> Generator[str, None, None]:
|
||||
"""Generate SSE events for workflow execution."""
|
||||
try:
|
||||
# TODO: Implement actual event streaming
|
||||
# This would connect to the workflow execution engine
|
||||
# and stream real-time events
|
||||
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT:
|
||||
raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}")
|
||||
|
||||
# For demo purposes, send a basic event
|
||||
yield f"data: {{'event': 'workflow_resumed', 'task_id': '{task_id}'}}\n\n"
|
||||
if workflow_run.created_by != user.id:
|
||||
raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}")
|
||||
|
||||
# In real implementation, this would:
|
||||
# 1. Connect to workflow execution engine
|
||||
# 2. Stream real-time execution events
|
||||
# 3. Handle client disconnection
|
||||
# 4. Clean up resources on completion
|
||||
with Session(expire_on_commit=False, bind=db.engine) as session:
|
||||
app = _retrieve_app_for_workflow_run(session, workflow_run)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error streaming events for task %s", task_id)
|
||||
yield f"data: {{'error': 'Stream error: {str(e)}'}}\n\n"
|
||||
if workflow_run.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
workflow_run=workflow_run,
|
||||
creator_user=user,
|
||||
)
|
||||
|
||||
# We'll
|
||||
def generate_events() -> Generator[str, None, None]:
|
||||
"""Generate SSE events for workflow execution."""
|
||||
try:
|
||||
# TODO: Implement actual event streaming
|
||||
# This would connect to the workflow execution engine
|
||||
# and stream real-time events
|
||||
|
||||
# For demo purposes, send a basic event
|
||||
yield f"data: {{'event': 'workflow_resumed', 'task_id': '{task_id}'}}\n\n"
|
||||
|
||||
# In real implementation, this would:
|
||||
# 1. Connect to workflow execution engine
|
||||
# 2. Stream real-time execution events
|
||||
# 3. Handle client disconnection
|
||||
# 4. Clean up resources on completion
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error streaming events for task %s", task_id)
|
||||
yield f"data: {{'error': 'Stream error: {str(e)}'}}\n\n"
|
||||
else:
|
||||
# TODO: SSE from Redis PubSub
|
||||
queue = ...
|
||||
|
||||
def generate_events():
|
||||
yield from []
|
||||
|
||||
return Response(
|
||||
generate_events(),
|
||||
@ -160,6 +196,7 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workflow/<string:workflow_run_id>/pause-details")
|
||||
class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
"""Console API for getting workflow pause details."""
|
||||
|
||||
@ -222,8 +259,14 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
return response, 200
|
||||
|
||||
|
||||
# Register the APIs
|
||||
api.add_resource(ConsoleHumanInputFormApi, "/form/human_input/<string:form_id>")
|
||||
api.add_resource(ConsoleHumanInputFormSubmissionApi, "/form/human_input/<string:form_id>", methods=["POST"])
|
||||
api.add_resource(ConsoleWorkflowEventsApi, "/workflow/<string:workflow_run_id>/events")
|
||||
api.add_resource(ConsoleWorkflowPauseDetailsApi, "/workflow/<string:workflow_run_id>/pause-details")
|
||||
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
|
||||
query = select(App).where(
|
||||
App.id == workflow_run.app_id,
|
||||
App.tenant_id == workflow_run.tenant_id,
|
||||
)
|
||||
app = session.scalars(query).first()
|
||||
if app is None:
|
||||
raise AssertionError(
|
||||
f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, "
|
||||
f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}"
|
||||
)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import (
|
||||
@ -405,6 +406,11 @@ class Workflow(Base): # bug
|
||||
|
||||
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
|
||||
|
||||
@deprecated(
|
||||
"This property is not accurate for determining if a workflow is published as a tool."
|
||||
"It only checks if there's a WorkflowToolProvider for the app, "
|
||||
"not if this specific workflow version is the one being used by the tool."
|
||||
)
|
||||
@property
|
||||
def tool_published(self) -> bool:
|
||||
"""
|
||||
@ -616,7 +622,7 @@ class WorkflowRun(Base):
|
||||
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
|
||||
total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) # account, end_user
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
@ -631,11 +637,13 @@ class WorkflowRun(Base):
|
||||
back_populates="workflow_run",
|
||||
)
|
||||
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
@property
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
@property
|
||||
def created_by_end_user(self):
|
||||
from .model import EndUser
|
||||
@ -655,6 +663,7 @@ class WorkflowRun(Base):
|
||||
def outputs_dict(self) -> Mapping[str, Any]:
|
||||
return json.loads(self.outputs) if self.outputs else {}
|
||||
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
@property
|
||||
def message(self):
|
||||
from .model import Message
|
||||
@ -663,6 +672,7 @@ class WorkflowRun(Base):
|
||||
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
|
||||
)
|
||||
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
@property
|
||||
def workflow(self):
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
|
||||
@ -479,3 +479,19 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||
[{"date": "2024-01-01", "interactions": 2.5}, ...]
|
||||
"""
|
||||
...
|
||||
|
||||
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
|
||||
"""
|
||||
Get a specific workflow run by its id and the associated tenant id.
|
||||
|
||||
This function does not apply application isolation. It should only be used when
|
||||
the application identifier is not available.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier for multi-tenant isolation
|
||||
run_id: Workflow run identifier
|
||||
|
||||
Returns:
|
||||
WorkflowRun object if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@ -839,6 +839,15 @@ GROUP BY
|
||||
|
||||
return cast(list[AverageInteractionStats], response_data)
|
||||
|
||||
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
|
||||
"""Get a specific workflow run by its id and the associated tenant id."""
|
||||
with self._session_maker() as session:
|
||||
stmt = select(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.id == run_id,
|
||||
)
|
||||
return session.scalar(stmt)
|
||||
|
||||
|
||||
class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user