mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
WIP: feat(api): human input service
This commit is contained in:
parent
c7957d5740
commit
c0e15b9e1b
@ -34,6 +34,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _FormDefinitionWithSite(FormDefinition):
|
||||
# the site field may be not necessary for console scenario.
|
||||
site: None
|
||||
|
||||
|
||||
|
||||
@ -4,55 +4,47 @@ Web App Human Input Form APIs.
|
||||
|
||||
import logging
|
||||
|
||||
from flask import jsonify
|
||||
from flask_restful import reqparse
|
||||
from flask import Response
|
||||
from flask_restx import reqparse
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.error import (
|
||||
NotFoundError,
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from models.human_input import HumanInputSubmissionType
|
||||
from models.human_input import RecipientType
|
||||
from models.model import App, EndUser
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormApi(WebApiResource):
|
||||
"""API for getting human input form definition."""
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
"""Return the Pydantic definition as a JSON response."""
|
||||
return Response(form.get_definition().model_dump_json(), mimetype="application/json")
|
||||
|
||||
def get(self, web_app_form_token: str):
|
||||
|
||||
@web_ns.route("/form/human_input/<string:web_app_form_token>")
|
||||
class HumanInputFormApi(WebApiResource):
|
||||
"""API for getting and submitting human input forms via the web app."""
|
||||
|
||||
def get(self, _app_model: App, _end_user: EndUser, web_app_form_token: str):
|
||||
"""
|
||||
Get human input form definition by token.
|
||||
|
||||
GET /api/form/human_input/<web_app_form_token>
|
||||
"""
|
||||
service = HumanInputService(db.engine)
|
||||
try:
|
||||
service = HumanInputFormService(db.session())
|
||||
form_definition = service.get_form_definition(
|
||||
identifier=web_app_form_token, is_token=True, include_site_info=True
|
||||
)
|
||||
return form_definition, 200
|
||||
|
||||
except HumanInputFormNotFoundError:
|
||||
form = service.get_form_definition_by_token(RecipientType.WEBAPP, web_app_form_token)
|
||||
except FormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
except HumanInputFormExpiredError:
|
||||
return jsonify(
|
||||
{"error_code": "human_input_form_expired", "description": "Human input form has expired"}
|
||||
), 400
|
||||
except HumanInputFormAlreadySubmittedError:
|
||||
return jsonify(
|
||||
{
|
||||
"error_code": "human_input_form_submitted",
|
||||
"description": "Human input form has already been submitted",
|
||||
}
|
||||
), 400
|
||||
|
||||
if form is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
class HumanInputFormSubmissionApi(WebApiResource):
|
||||
"""API for submitting human input forms."""
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
def post(self, web_app_form_token: str):
|
||||
def post(self, _app_model: App, _end_user: EndUser, web_app_form_token: str):
|
||||
"""
|
||||
Submit human input form by token.
|
||||
|
||||
@ -71,36 +63,15 @@ class HumanInputFormSubmissionApi(WebApiResource):
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
try:
|
||||
# Submit the form
|
||||
service = HumanInputFormService(db.session())
|
||||
service.submit_form(
|
||||
identifier=web_app_form_token,
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
form_token=web_app_form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
action=args["action"],
|
||||
is_token=True,
|
||||
submission_type=HumanInputSubmissionType.web_app,
|
||||
)
|
||||
|
||||
return {}, 200
|
||||
|
||||
except HumanInputFormNotFoundError:
|
||||
except FormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
except HumanInputFormExpiredError:
|
||||
return jsonify(
|
||||
{"error_code": "human_input_form_expired", "description": "Human input form has expired"}
|
||||
), 400
|
||||
except HumanInputFormAlreadySubmittedError:
|
||||
return jsonify(
|
||||
{
|
||||
"error_code": "human_input_form_submitted",
|
||||
"description": "Human input form has already been submitted",
|
||||
}
|
||||
), 400
|
||||
except InvalidFormDataError as e:
|
||||
return jsonify({"error_code": "invalid_form_data", "description": e.message}), 400
|
||||
|
||||
|
||||
# Register the APIs
|
||||
api.add_resource(HumanInputFormApi, "/form/human_input/<string:web_app_form_token>")
|
||||
api.add_resource(HumanInputFormSubmissionApi, "/form/human_input/<string:web_app_form_token>", methods=["POST"])
|
||||
return {}, 200
|
||||
|
||||
@ -1,19 +1,28 @@
|
||||
"""
|
||||
Repository implementations for data access.
|
||||
"""Repository implementations for data access."""
|
||||
|
||||
This package contains concrete implementations of the repository interfaces
|
||||
defined in the core.workflow.repository package.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"CeleryWorkflowExecutionRepository",
|
||||
"CeleryWorkflowNodeExecutionRepository",
|
||||
"DifyCoreRepositoryFactory",
|
||||
"RepositoryImportError",
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
]
|
||||
_ATTRIBUTE_MODULE_MAP = {
|
||||
"CeleryWorkflowExecutionRepository": "core.repositories.celery_workflow_execution_repository",
|
||||
"CeleryWorkflowNodeExecutionRepository": "core.repositories.celery_workflow_node_execution_repository",
|
||||
"DifyCoreRepositoryFactory": "core.repositories.factory",
|
||||
"RepositoryImportError": "core.repositories.factory",
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository": "core.repositories.sqlalchemy_workflow_node_execution_repository",
|
||||
}
|
||||
|
||||
__all__ = list(_ATTRIBUTE_MODULE_MAP.keys())
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
module_path = _ATTRIBUTE_MODULE_MAP.get(name)
|
||||
if module_path is None:
|
||||
raise AttributeError(f"module 'core.repositories' has no attribute '{name}'")
|
||||
module = import_module(module_path)
|
||||
return getattr(module, name)
|
||||
|
||||
|
||||
def __dir__() -> list[str]: # pragma: no cover - simple helper
|
||||
return sorted(__all__)
|
||||
|
||||
@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker, selectinload
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
DeliveryChannelConfig,
|
||||
@ -19,17 +19,18 @@ from core.workflow.nodes.human_input.entities import (
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
FormNotFoundError,
|
||||
FormSubmissionEntity,
|
||||
FormSubmission,
|
||||
HumanInputFormEntity,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.human_input import (
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputRecipient,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
WebAppRecipientPayload,
|
||||
)
|
||||
@ -38,14 +39,20 @@ from models.human_input import (
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _DeliveryAndRecipients:
|
||||
delivery: HumanInputDelivery
|
||||
recipients: Sequence[HumanInputRecipient]
|
||||
recipients: Sequence[HumanInputFormRecipient]
|
||||
|
||||
def webapp_recipient(self) -> HumanInputRecipient | None:
|
||||
def webapp_recipient(self) -> HumanInputFormRecipient | None:
|
||||
return next((i for i in self.recipients if i.recipient_type == RecipientType.WEBAPP), None)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _WorkspaceMemberInfo:
|
||||
user_id: str
|
||||
email: str
|
||||
|
||||
|
||||
class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
def __init__(self, form_model: HumanInputForm, web_app_recipient: HumanInputRecipient | None):
|
||||
def __init__(self, form_model: HumanInputForm, web_app_recipient: HumanInputFormRecipient | None):
|
||||
self._form_model = form_model
|
||||
self._web_app_recipient = web_app_recipient
|
||||
|
||||
@ -60,7 +67,7 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
return self._web_app_recipient.access_token
|
||||
|
||||
|
||||
class _FormSubmissionEntityImpl(FormSubmissionEntity):
|
||||
class _FormSubmissionImpl(FormSubmission):
|
||||
def __init__(self, form_model: HumanInputForm):
|
||||
self._form_model = form_model
|
||||
|
||||
@ -78,37 +85,23 @@ class _FormSubmissionEntityImpl(FormSubmissionEntity):
|
||||
return json.loads(submitted_data)
|
||||
|
||||
|
||||
class WorkspaceMember:
|
||||
def user_id(self) -> str:
|
||||
pass
|
||||
|
||||
def email(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class WorkspaceMemberQueirer:
|
||||
def get_all_workspace_members(self) -> Sequence[WorkspaceMember]:
|
||||
# TOOD: need a way to query all members in the current workspace.
|
||||
pass
|
||||
|
||||
def get_members_by_ids(self, user_ids: Sequence[str]) -> Sequence[WorkspaceMember]:
|
||||
pass
|
||||
|
||||
|
||||
class HumanInputFormRepositoryImpl:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
tenant_id: str,
|
||||
member_quierer: WorkspaceMemberQueirer,
|
||||
):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
self._tenant_id = tenant_id
|
||||
self._member_queirer = member_quierer
|
||||
|
||||
def _delivery_method_to_model(self, form_id, delivery_method: DeliveryChannelConfig) -> _DeliveryAndRecipients:
|
||||
def _delivery_method_to_model(
|
||||
self,
|
||||
session: Session,
|
||||
form_id: str,
|
||||
delivery_method: DeliveryChannelConfig,
|
||||
) -> _DeliveryAndRecipients:
|
||||
delivery_id = str(uuidv7())
|
||||
delivery_model = HumanInputDelivery(
|
||||
id=delivery_id,
|
||||
@ -117,9 +110,9 @@ class HumanInputFormRepositoryImpl:
|
||||
delivery_config_id=delivery_method.id,
|
||||
channel_payload=delivery_method.model_dump_json(),
|
||||
)
|
||||
recipients: list[HumanInputRecipient] = []
|
||||
recipients: list[HumanInputFormRecipient] = []
|
||||
if isinstance(delivery_method, WebAppDeliveryMethod):
|
||||
recipient_model = HumanInputRecipient(
|
||||
recipient_model = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
@ -129,11 +122,20 @@ class HumanInputFormRepositoryImpl:
|
||||
elif isinstance(delivery_method, EmailDeliveryMethod):
|
||||
email_recipients_config = delivery_method.config.recipients
|
||||
if email_recipients_config.whole_workspace:
|
||||
recipients.extend(self._create_whole_workspace_recipients(form_id=form_id, delivery_id=delivery_id))
|
||||
recipients.extend(
|
||||
self._create_whole_workspace_recipients(
|
||||
session=session,
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
recipients.extend(
|
||||
self._create_email_recipients(
|
||||
form_id=form_id, delivery_id=delivery_id, recipients=email_recipients_config.items
|
||||
session=session,
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipients=email_recipients_config.items,
|
||||
)
|
||||
)
|
||||
|
||||
@ -141,27 +143,34 @@ class HumanInputFormRepositoryImpl:
|
||||
|
||||
def _create_email_recipients(
|
||||
self,
|
||||
session: Session,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
recipients: Sequence[EmailRecipient],
|
||||
) -> list[HumanInputRecipient]:
|
||||
recipient_models: list[HumanInputRecipient] = []
|
||||
) -> list[HumanInputFormRecipient]:
|
||||
recipient_models: list[HumanInputFormRecipient] = []
|
||||
member_user_ids: list[str] = []
|
||||
for r in recipients:
|
||||
if isinstance(r, MemberRecipient):
|
||||
member_user_ids.append(r.user_id)
|
||||
elif isinstance(r, ExternalRecipient):
|
||||
recipient_model = HumanInputRecipient.new(
|
||||
recipient_model = HumanInputFormRecipient.new(
|
||||
form_id=form_id, delivery_id=delivery_id, payload=EmailExternalRecipientPayload(email=r.email)
|
||||
)
|
||||
recipient_models.append(recipient_model)
|
||||
else:
|
||||
raise AssertionError(f"unknown recipient type: recipient={r}")
|
||||
|
||||
members = self._member_queirer.get_members_by_ids(member_user_ids)
|
||||
for member in members:
|
||||
payload = EmailMemberRecipientPayload(user_id=member.user_id(), email=member.email())
|
||||
recipient_model = HumanInputRecipient.new(
|
||||
member_entries = {
|
||||
member.user_id: member.email
|
||||
for member in self._query_workspace_members(session=session, user_ids=member_user_ids)
|
||||
}
|
||||
for user_id in member_user_ids:
|
||||
email = member_entries.get(user_id)
|
||||
if email is None:
|
||||
continue
|
||||
payload = EmailMemberRecipientPayload(user_id=user_id, email=email)
|
||||
recipient_model = HumanInputFormRecipient.new(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
payload=payload,
|
||||
@ -169,12 +178,17 @@ class HumanInputFormRepositoryImpl:
|
||||
recipient_models.append(recipient_model)
|
||||
return recipient_models
|
||||
|
||||
def _create_whole_workspace_recipients(self, form_id: str, delivery_id: str) -> list[HumanInputRecipient]:
|
||||
def _create_whole_workspace_recipients(
|
||||
self,
|
||||
session: Session,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
) -> list[HumanInputFormRecipient]:
|
||||
recipeint_models = []
|
||||
members = self._member_queirer.get_all_workspace_members()
|
||||
members = self._query_workspace_members(session=session, user_ids=None)
|
||||
for member in members:
|
||||
payload = EmailMemberRecipientPayload(user_id=member.user_id(), email=member.email())
|
||||
recipient_model = HumanInputRecipient.new(
|
||||
payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email)
|
||||
recipient_model = HumanInputFormRecipient.new(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
payload=payload,
|
||||
@ -183,6 +197,30 @@ class HumanInputFormRepositoryImpl:
|
||||
|
||||
return recipeint_models
|
||||
|
||||
def _query_workspace_members(
|
||||
self,
|
||||
session: Session,
|
||||
user_ids: Sequence[str] | None,
|
||||
) -> list[_WorkspaceMemberInfo]:
|
||||
unique_ids: set[str] | None
|
||||
if user_ids is None:
|
||||
unique_ids = None
|
||||
else:
|
||||
unique_ids = {user_id for user_id in user_ids if user_id}
|
||||
if not unique_ids:
|
||||
return []
|
||||
|
||||
stmt = (
|
||||
select(Account.id, Account.email)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
|
||||
.where(TenantAccountJoin.tenant_id == self._tenant_id)
|
||||
)
|
||||
if unique_ids is not None:
|
||||
stmt = stmt.where(Account.id.in_(unique_ids))
|
||||
|
||||
rows = session.execute(stmt).all()
|
||||
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
form_config: HumanInputNodeData = params.form_config
|
||||
|
||||
@ -190,8 +228,12 @@ class HumanInputFormRepositoryImpl:
|
||||
# Generate unique form ID
|
||||
form_id = str(uuidv7())
|
||||
form_definition = FormDefinition(
|
||||
form_content=form_config.form_content,
|
||||
inputs=form_config.inputs,
|
||||
user_actions=form_config.user_actions,
|
||||
rendered_content=params.rendered_content,
|
||||
timeout=form_config.timeout,
|
||||
timeout_unit=form_config.timeout_unit,
|
||||
)
|
||||
form_model = HumanInputForm(
|
||||
id=form_id,
|
||||
@ -203,9 +245,13 @@ class HumanInputFormRepositoryImpl:
|
||||
expiration_time=form_config.expiration_time(naive_utc_now()),
|
||||
)
|
||||
session.add(form_model)
|
||||
web_app_recipient: HumanInputRecipient | None = None
|
||||
web_app_recipient: HumanInputFormRecipient | None = None
|
||||
for delivery in form_config.delivery_methods:
|
||||
delivery_and_recipients = self._delivery_method_to_model(form_id, delivery)
|
||||
delivery_and_recipients = self._delivery_method_to_model(
|
||||
session=session,
|
||||
form_id=form_id,
|
||||
delivery_method=delivery,
|
||||
)
|
||||
session.add(delivery_and_recipients.delivery)
|
||||
session.add_all(delivery_and_recipients.recipients)
|
||||
if web_app_recipient is None:
|
||||
@ -214,7 +260,7 @@ class HumanInputFormRepositoryImpl:
|
||||
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, web_app_recipient=web_app_recipient)
|
||||
|
||||
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmissionEntity | None:
|
||||
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmission | None:
|
||||
query = select(HumanInputForm).where(
|
||||
HumanInputForm.workflow_run_id == workflow_execution_id,
|
||||
HumanInputForm.node_id == node_id,
|
||||
@ -227,4 +273,13 @@ class HumanInputFormRepositoryImpl:
|
||||
if form_model.submitted_at is None:
|
||||
return None
|
||||
|
||||
return _FormSubmissionEntityImpl(form_model=form_model)
|
||||
return _FormSubmissionImpl(form_model=form_model)
|
||||
|
||||
def get_form_by_token(self, token: str, recipient_type: RecipientType | None = None):
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where()
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_recipient = session.qu
|
||||
|
||||
@ -48,8 +48,26 @@ class HumanInputFormEntity(abc.ABC):
|
||||
# webapp delivery?
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def recipients(self) -> list["HumanInputFormRecipientEntity"]: ...
|
||||
|
||||
class FormSubmissionEntity(abc.ABC):
|
||||
|
||||
class HumanInputFormRecipientEntity(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def id(self) -> str:
|
||||
"""id returns the identifer of this recipient."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def token(self) -> str:
|
||||
"""token returns a random string used to submit form"""
|
||||
...
|
||||
|
||||
|
||||
class FormSubmission(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def selected_action_id(self) -> str:
|
||||
@ -81,7 +99,7 @@ class HumanInputFormRepository(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmissionEntity | None:
|
||||
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmission | None:
|
||||
"""Retrieve the submission for a specific human input node.
|
||||
|
||||
Returns `FormSubmission` if the form has been submitted, or `None` if not.
|
||||
|
||||
@ -66,7 +66,7 @@ class HumanInputForm(DefaultFieldsMixin, Base):
|
||||
back_populates="form",
|
||||
lazy="raise",
|
||||
)
|
||||
completed_by_recipient: Mapped["HumanInputRecipient | None"] = relationship(
|
||||
completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship(
|
||||
"HumanInputRecipient",
|
||||
primaryjoin="HumanInputForm.completed_by_recipient_id == HumanInputRecipient.id",
|
||||
lazy="raise",
|
||||
@ -75,7 +75,7 @@ class HumanInputForm(DefaultFieldsMixin, Base):
|
||||
|
||||
|
||||
class HumanInputDelivery(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_deliveries"
|
||||
__tablename__ = "human_input_form_deliveries"
|
||||
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
@ -94,10 +94,11 @@ class HumanInputDelivery(DefaultFieldsMixin, Base):
|
||||
back_populates="deliveries",
|
||||
lazy="raise",
|
||||
)
|
||||
recipients: Mapped[list["HumanInputRecipient"]] = relationship(
|
||||
recipients: Mapped[list["HumanInputFormRecipient"]] = relationship(
|
||||
"HumanInputRecipient",
|
||||
uselist=True,
|
||||
back_populates="delivery",
|
||||
cascade="all, delete-orphan",
|
||||
# Require explicit preloading
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
@ -135,8 +136,8 @@ RecipientPayload = Annotated[
|
||||
]
|
||||
|
||||
|
||||
class HumanInputRecipient(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_recipients"
|
||||
class HumanInputFormRecipient(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_form_recipients"
|
||||
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
@ -158,7 +159,19 @@ class HumanInputRecipient(DefaultFieldsMixin, Base):
|
||||
|
||||
delivery: Mapped[HumanInputDelivery] = relationship(
|
||||
"HumanInputDelivery",
|
||||
uselist=False,
|
||||
foreign_keys=[delivery_id],
|
||||
back_populates="recipients",
|
||||
# Require explicit preloading
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
form: Mapped[HumanInputForm] = relationship(
|
||||
"HumanInputForm",
|
||||
uselist=False,
|
||||
foreign_keys=[form_id],
|
||||
back_populates="recipients",
|
||||
# Require explicit preloading
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
|
||||
@ -637,14 +637,14 @@ class WorkflowRun(Base):
|
||||
back_populates="workflow_run",
|
||||
)
|
||||
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def created_by_end_user(self):
|
||||
from .model import EndUser
|
||||
|
||||
@ -663,8 +663,8 @@ class WorkflowRun(Base):
|
||||
def outputs_dict(self) -> Mapping[str, Any]:
|
||||
return json.loads(self.outputs) if self.outputs else {}
|
||||
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def message(self):
|
||||
from .model import Message
|
||||
|
||||
@ -672,8 +672,8 @@ class WorkflowRun(Base):
|
||||
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
|
||||
)
|
||||
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def workflow(self):
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
|
||||
|
||||
@ -1,16 +1,22 @@
|
||||
import abc
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.exception import BaseHTTPException
|
||||
from models.human_input import RecipientType
|
||||
from models.account import Account
|
||||
from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType
|
||||
|
||||
|
||||
class Form(abc.ABC):
|
||||
class Form:
|
||||
def __init__(self, form_model: HumanInputForm):
|
||||
self._form_model = form_model
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_definition(self) -> FormDefinition:
|
||||
pass
|
||||
@ -34,30 +40,102 @@ class FormSubmittedError(HumanInputError, BaseHTTPException):
|
||||
super().__init__(description=description)
|
||||
|
||||
|
||||
class FormNotFoundError(HumanInputError, BaseException):
|
||||
class FormNotFoundError(HumanInputError, BaseHTTPException):
|
||||
error_code = "human_input_form_not_found"
|
||||
code = 404
|
||||
|
||||
|
||||
class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
|
||||
pass
|
||||
|
||||
|
||||
class HumanInputService:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
session_factory: sessionmaker[Session] | Engine,
|
||||
):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
|
||||
def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form:
|
||||
pass
|
||||
def get_form_by_token(self, form_token: str) -> Form | None:
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(HumanInputFormRecipient.access_token == form_token)
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient = session.scalars(query).first()
|
||||
if recipient is None:
|
||||
return None
|
||||
|
||||
def get_form_definition_by_id(self, form_id: str) -> Form | None:
|
||||
pass
|
||||
return Form(recipient.form)
|
||||
|
||||
def submit_form_by_id(self, form_id: str, selected_action_id: str, form_data: Mapping[str, Any]):
|
||||
pass
|
||||
def get_form_by_id(self, form_id: str) -> Form | None:
|
||||
query = select(HumanInputForm).where(HumanInputForm.id == form_id)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_model = session.scalars(query).first()
|
||||
if form_model is None:
|
||||
return None
|
||||
|
||||
def submit_form_by_token(
|
||||
self, recipient_type: RecipientType, form_token: str, selected_action_id: str, form_data: Mapping[str, Any]
|
||||
):
|
||||
pass
|
||||
return Form(form_model)
|
||||
|
||||
def submit_form_by_id(self, form_id: str, user: Account, selected_action_id: str, form_data: Mapping[str, Any]):
|
||||
recipient_query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(
|
||||
HumanInputFormRecipient.recipient_type == RecipientType.WEBAPP,
|
||||
HumanInputFormRecipient.form_id == form_id,
|
||||
)
|
||||
)
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(recipient_query).first()
|
||||
|
||||
if recipient_model is None:
|
||||
raise WebAppDeliveryNotEnabledError()
|
||||
|
||||
form_model = recipient_model.form
|
||||
form = Form(form_model)
|
||||
if form.submitted:
|
||||
raise FormSubmittedError(form_model.id)
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model.selected_action_id = selected_action_id
|
||||
form_model.submitted_data = json.dumps(form_data)
|
||||
form_model.submitted_at = naive_utc_now()
|
||||
form_model.submission_user_id = user.id
|
||||
|
||||
form_model.completed_by_recipient_id = recipient_model.id
|
||||
session.add(form_model)
|
||||
# TODO: restart the execution of paused workflow
|
||||
|
||||
def submit_form_by_token(self, form_token: str, selected_action_id: str, form_data: Mapping[str, Any]):
|
||||
recipient_query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(
|
||||
HumanInputFormRecipient.form_id == form_token,
|
||||
)
|
||||
)
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(recipient_query).first()
|
||||
|
||||
if recipient_model is None:
|
||||
raise WebAppDeliveryNotEnabledError()
|
||||
|
||||
form_model = recipient_model.form
|
||||
form = Form(form_model)
|
||||
if form.submitted:
|
||||
raise FormSubmittedError(form_model.id)
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model.selected_action_id = selected_action_id
|
||||
form_model.submitted_data = json.dumps(form_data)
|
||||
form_model.submitted_at = naive_utc_now()
|
||||
form_model.submission_user_id = user.id
|
||||
|
||||
form_model.completed_by_recipient_id = recipient_model.id
|
||||
session.add(form_model)
|
||||
|
||||
@ -0,0 +1,158 @@
|
||||
"""TestContainers integration tests for HumanInputFormRepositoryImpl."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.repositories.human_input_form_repository import FormCreateParams
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.human_input import (
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
|
||||
|
||||
def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]:
|
||||
tenant = Tenant(name="Test Tenant", status="normal")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
members: list[Account] = []
|
||||
for index, email in enumerate(member_emails):
|
||||
account = Account(
|
||||
email=email,
|
||||
name=f"Member {index}",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
session.add(account)
|
||||
session.flush()
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.NORMAL,
|
||||
current=True,
|
||||
)
|
||||
session.add(tenant_join)
|
||||
members.append(account)
|
||||
|
||||
session.commit()
|
||||
return tenant, members
|
||||
|
||||
|
||||
def _build_form_params(delivery_methods: list[EmailDeliveryMethod]) -> FormCreateParams:
|
||||
form_config = HumanInputNodeData(
|
||||
title="Human Approval",
|
||||
delivery_methods=delivery_methods,
|
||||
form_content="<p>Approve?</p>",
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
return FormCreateParams(
|
||||
workflow_execution_id=str(uuid4()),
|
||||
node_id="human-input-node",
|
||||
form_config=form_config,
|
||||
rendered_content="<p>Approve?</p>",
|
||||
)
|
||||
|
||||
|
||||
def _build_email_delivery(
|
||||
whole_workspace: bool, recipients: list[MemberRecipient | ExternalRecipient]
|
||||
) -> EmailDeliveryMethod:
|
||||
return EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients),
|
||||
subject="Approval Needed",
|
||||
body="Please review",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestHumanInputFormRepositoryImplWithContainers:
|
||||
def test_create_form_with_whole_workspace_recipients(self, db_session_with_containers: Session) -> None:
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
tenant, members = _create_tenant_with_members(
|
||||
db_session_with_containers,
|
||||
member_emails=["member1@example.com", "member2@example.com"],
|
||||
)
|
||||
|
||||
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
|
||||
params = _build_form_params(
|
||||
delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])],
|
||||
)
|
||||
|
||||
form_entity = repository.create_form(params)
|
||||
|
||||
with Session(engine) as verification_session:
|
||||
recipients = verification_session.scalars(
|
||||
select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id)
|
||||
).all()
|
||||
|
||||
assert len(recipients) == len(members)
|
||||
member_payloads = [
|
||||
EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload)
|
||||
for recipient in recipients
|
||||
if recipient.recipient_type == RecipientType.EMAIL_MEMBER
|
||||
]
|
||||
member_emails = {payload.email for payload in member_payloads}
|
||||
assert member_emails == {member.email for member in members}
|
||||
|
||||
def test_create_form_with_specific_members_and_external(self, db_session_with_containers: Session) -> None:
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
tenant, members = _create_tenant_with_members(
|
||||
db_session_with_containers,
|
||||
member_emails=["primary@example.com", "secondary@example.com"],
|
||||
)
|
||||
|
||||
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
|
||||
params = _build_form_params(
|
||||
delivery_methods=[
|
||||
_build_email_delivery(
|
||||
whole_workspace=False,
|
||||
recipients=[
|
||||
MemberRecipient(user_id=members[0].id),
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
form_entity = repository.create_form(params)
|
||||
|
||||
with Session(engine) as verification_session:
|
||||
recipients = verification_session.scalars(
|
||||
select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id)
|
||||
).all()
|
||||
|
||||
member_recipient_payloads = [
|
||||
EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload)
|
||||
for recipient in recipients
|
||||
if recipient.recipient_type == RecipientType.EMAIL_MEMBER
|
||||
]
|
||||
assert len(member_recipient_payloads) == 1
|
||||
assert member_recipient_payloads[0].user_id == members[0].id
|
||||
|
||||
external_payloads = [
|
||||
EmailExternalRecipientPayload.model_validate_json(recipient.recipient_payload)
|
||||
for recipient in recipients
|
||||
if recipient.recipient_type == RecipientType.EMAIL_EXTERNAL
|
||||
]
|
||||
assert len(external_payloads) == 1
|
||||
assert external_payloads[0].email == "external@example.com"
|
||||
@ -0,0 +1,127 @@
|
||||
"""Unit tests for HumanInputFormRepositoryImpl private helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.human_input_reposotiry import (
|
||||
HumanInputFormRepositoryImpl,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import ExternalRecipient, MemberRecipient
|
||||
from models.human_input import (
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
|
||||
|
||||
def _build_repository() -> HumanInputFormRepositoryImpl:
|
||||
return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id")
|
||||
|
||||
|
||||
def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]:
|
||||
created: list[SimpleNamespace] = []
|
||||
|
||||
def fake_new(cls, form_id: str, delivery_id: str, payload): # type: ignore[no-untyped-def]
|
||||
recipient = SimpleNamespace(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=payload.TYPE,
|
||||
recipient_payload=payload.model_dump_json(),
|
||||
)
|
||||
created.append(recipient)
|
||||
return recipient
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRecipient, "new", classmethod(fake_new))
|
||||
return created
|
||||
|
||||
|
||||
class TestHumanInputFormRepositoryImplHelpers:
|
||||
def test_create_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
_patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session, user_ids): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
assert user_ids == ["member-1"]
|
||||
return [_WorkspaceMemberInfo(user_id="member-1", email="member@example.com")]
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members", fake_query)
|
||||
|
||||
recipients = repo._create_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients=[
|
||||
MemberRecipient(user_id="member-1"),
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
],
|
||||
)
|
||||
|
||||
assert len(recipients) == 2
|
||||
member_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_MEMBER)
|
||||
external_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL)
|
||||
|
||||
member_payload = EmailMemberRecipientPayload.model_validate_json(member_recipient.recipient_payload)
|
||||
assert member_payload.user_id == "member-1"
|
||||
assert member_payload.email == "member@example.com"
|
||||
|
||||
external_payload = EmailExternalRecipientPayload.model_validate_json(external_recipient.recipient_payload)
|
||||
assert external_payload.email == "external@example.com"
|
||||
|
||||
def test_create_email_recipients_skips_unknown_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
created = _patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session, user_ids): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
assert user_ids == ["missing-member"]
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members", fake_query)
|
||||
|
||||
recipients = repo._create_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients=[
|
||||
MemberRecipient(user_id="missing-member"),
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
],
|
||||
)
|
||||
|
||||
assert len(recipients) == 1
|
||||
assert recipients[0].recipient_type == RecipientType.EMAIL_EXTERNAL
|
||||
assert len(created) == 1 # only external recipient created via factory
|
||||
|
||||
def test_create_whole_workspace_recipients_uses_all_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
_patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session, user_ids): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
assert user_ids is None
|
||||
return [
|
||||
_WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"),
|
||||
_WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members", fake_query)
|
||||
|
||||
recipients = repo._create_whole_workspace_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
)
|
||||
|
||||
assert len(recipients) == 2
|
||||
emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients}
|
||||
assert emails == {"member1@example.com", "member2@example.com"}
|
||||
Loading…
Reference in New Issue
Block a user