diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 294d1f7bcc..0fe7148cf1 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -34,6 +34,7 @@ logger = logging.getLogger(__name__) class _FormDefinitionWithSite(FormDefinition): + # the site field may be not necessary for console scenario. site: None diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py index 3c7ea428e4..1e22ae94ce 100644 --- a/api/controllers/web/human_input_form.py +++ b/api/controllers/web/human_input_form.py @@ -4,55 +4,47 @@ Web App Human Input Form APIs. import logging -from flask import jsonify -from flask_restful import reqparse +from flask import Response +from flask_restx import reqparse -from controllers.web import api -from controllers.web.error import ( - NotFoundError, -) +from controllers.web import web_ns +from controllers.web.error import NotFoundError from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from models.human_input import HumanInputSubmissionType +from models.human_input import RecipientType +from models.model import App, EndUser +from services.human_input_service import Form, FormNotFoundError, HumanInputService logger = logging.getLogger(__name__) -class HumanInputFormApi(WebApiResource): - """API for getting human input form definition.""" +def _jsonify_form_definition(form: Form) -> Response: + """Return the Pydantic definition as a JSON response.""" + return Response(form.get_definition().model_dump_json(), mimetype="application/json") - def get(self, web_app_form_token: str): + +@web_ns.route("/form/human_input/") +class HumanInputFormApi(WebApiResource): + """API for getting and submitting human input forms via the web app.""" + + def get(self, _app_model: App, _end_user: EndUser, web_app_form_token: str): """ Get human input form definition by token. GET /api/form/human_input/ """ + service = HumanInputService(db.engine) try: - service = HumanInputFormService(db.session()) - form_definition = service.get_form_definition( - identifier=web_app_form_token, is_token=True, include_site_info=True - ) - return form_definition, 200 - - except HumanInputFormNotFoundError: + form = service.get_form_definition_by_token(RecipientType.WEBAPP, web_app_form_token) + except FormNotFoundError: raise NotFoundError("Form not found") - except HumanInputFormExpiredError: - return jsonify( - {"error_code": "human_input_form_expired", "description": "Human input form has expired"} - ), 400 - except HumanInputFormAlreadySubmittedError: - return jsonify( - { - "error_code": "human_input_form_submitted", - "description": "Human input form has already been submitted", - } - ), 400 + if form is None: + raise NotFoundError("Form not found") -class HumanInputFormSubmissionApi(WebApiResource): - """API for submitting human input forms.""" + return _jsonify_form_definition(form) - def post(self, web_app_form_token: str): + def post(self, _app_model: App, _end_user: EndUser, web_app_form_token: str): """ Submit human input form by token. @@ -71,36 +63,15 @@ class HumanInputFormSubmissionApi(WebApiResource): parser.add_argument("action", type=str, required=True, location="json") args = parser.parse_args() + service = HumanInputService(db.engine) try: - # Submit the form - service = HumanInputFormService(db.session()) - service.submit_form( - identifier=web_app_form_token, + service.submit_form_by_token( + recipient_type=RecipientType.WEBAPP, + form_token=web_app_form_token, + selected_action_id=args["action"], form_data=args["inputs"], - action=args["action"], - is_token=True, - submission_type=HumanInputSubmissionType.web_app, ) - - return {}, 200 - - except HumanInputFormNotFoundError: + except FormNotFoundError: raise NotFoundError("Form not found") - except HumanInputFormExpiredError: - return jsonify( - {"error_code": "human_input_form_expired", "description": "Human input form has expired"} - ), 400 - except HumanInputFormAlreadySubmittedError: - return jsonify( - { - "error_code": "human_input_form_submitted", - "description": "Human input form has already been submitted", - } - ), 400 - except InvalidFormDataError as e: - return jsonify({"error_code": "invalid_form_data", "description": e.message}), 400 - -# Register the APIs -api.add_resource(HumanInputFormApi, "/form/human_input/") -api.add_resource(HumanInputFormSubmissionApi, "/form/human_input/", methods=["POST"]) + return {}, 200 diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index d83823d7b9..6c01056c4d 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -1,19 +1,28 @@ -""" -Repository implementations for data access. +"""Repository implementations for data access.""" -This package contains concrete implementations of the repository interfaces -defined in the core.workflow.repository package. -""" +from __future__ import annotations -from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError -from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from importlib import import_module +from typing import Any -__all__ = [ - "CeleryWorkflowExecutionRepository", - "CeleryWorkflowNodeExecutionRepository", - "DifyCoreRepositoryFactory", - "RepositoryImportError", - "SQLAlchemyWorkflowNodeExecutionRepository", -] +_ATTRIBUTE_MODULE_MAP = { + "CeleryWorkflowExecutionRepository": "core.repositories.celery_workflow_execution_repository", + "CeleryWorkflowNodeExecutionRepository": "core.repositories.celery_workflow_node_execution_repository", + "DifyCoreRepositoryFactory": "core.repositories.factory", + "RepositoryImportError": "core.repositories.factory", + "SQLAlchemyWorkflowNodeExecutionRepository": "core.repositories.sqlalchemy_workflow_node_execution_repository", +} + +__all__ = list(_ATTRIBUTE_MODULE_MAP.keys()) + + +def __getattr__(name: str) -> Any: + module_path = _ATTRIBUTE_MODULE_MAP.get(name) + if module_path is None: + raise AttributeError(f"module 'core.repositories' has no attribute '{name}'") + module = import_module(module_path) + return getattr(module, name) + + +def __dir__() -> list[str]: # pragma: no cover - simple helper + return sorted(__all__) diff --git a/api/core/repositories/human_input_reposotiry.py b/api/core/repositories/human_input_reposotiry.py index 3513a5da99..4fc84664dc 100644 --- a/api/core/repositories/human_input_reposotiry.py +++ b/api/core/repositories/human_input_reposotiry.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence from typing import Any from sqlalchemy import Engine, select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker, selectinload from core.workflow.nodes.human_input.entities import ( DeliveryChannelConfig, @@ -19,17 +19,18 @@ from core.workflow.nodes.human_input.entities import ( from core.workflow.repositories.human_input_form_repository import ( FormCreateParams, FormNotFoundError, - FormSubmissionEntity, + FormSubmission, HumanInputFormEntity, ) from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 +from models.account import Account, TenantAccountJoin from models.human_input import ( EmailExternalRecipientPayload, EmailMemberRecipientPayload, HumanInputDelivery, HumanInputForm, - HumanInputRecipient, + HumanInputFormRecipient, RecipientType, WebAppRecipientPayload, ) @@ -38,14 +39,20 @@ from models.human_input import ( @dataclasses.dataclass(frozen=True) class _DeliveryAndRecipients: delivery: HumanInputDelivery - recipients: Sequence[HumanInputRecipient] + recipients: Sequence[HumanInputFormRecipient] - def webapp_recipient(self) -> HumanInputRecipient | None: + def webapp_recipient(self) -> HumanInputFormRecipient | None: return next((i for i in self.recipients if i.recipient_type == RecipientType.WEBAPP), None) +@dataclasses.dataclass(frozen=True) +class _WorkspaceMemberInfo: + user_id: str + email: str + + class _HumanInputFormEntityImpl(HumanInputFormEntity): - def __init__(self, form_model: HumanInputForm, web_app_recipient: HumanInputRecipient | None): + def __init__(self, form_model: HumanInputForm, web_app_recipient: HumanInputFormRecipient | None): self._form_model = form_model self._web_app_recipient = web_app_recipient @@ -60,7 +67,7 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): return self._web_app_recipient.access_token -class _FormSubmissionEntityImpl(FormSubmissionEntity): +class _FormSubmissionImpl(FormSubmission): def __init__(self, form_model: HumanInputForm): self._form_model = form_model @@ -78,37 +85,23 @@ class _FormSubmissionEntityImpl(FormSubmissionEntity): return json.loads(submitted_data) -class WorkspaceMember: - def user_id(self) -> str: - pass - - def email(self) -> str: - pass - - -class WorkspaceMemberQueirer: - def get_all_workspace_members(self) -> Sequence[WorkspaceMember]: - # TOOD: need a way to query all members in the current workspace. - pass - - def get_members_by_ids(self, user_ids: Sequence[str]) -> Sequence[WorkspaceMember]: - pass - - class HumanInputFormRepositoryImpl: def __init__( self, session_factory: sessionmaker | Engine, tenant_id: str, - member_quierer: WorkspaceMemberQueirer, ): if isinstance(session_factory, Engine): session_factory = sessionmaker(bind=session_factory) self._session_factory = session_factory self._tenant_id = tenant_id - self._member_queirer = member_quierer - def _delivery_method_to_model(self, form_id, delivery_method: DeliveryChannelConfig) -> _DeliveryAndRecipients: + def _delivery_method_to_model( + self, + session: Session, + form_id: str, + delivery_method: DeliveryChannelConfig, + ) -> _DeliveryAndRecipients: delivery_id = str(uuidv7()) delivery_model = HumanInputDelivery( id=delivery_id, @@ -117,9 +110,9 @@ class HumanInputFormRepositoryImpl: delivery_config_id=delivery_method.id, channel_payload=delivery_method.model_dump_json(), ) - recipients: list[HumanInputRecipient] = [] + recipients: list[HumanInputFormRecipient] = [] if isinstance(delivery_method, WebAppDeliveryMethod): - recipient_model = HumanInputRecipient( + recipient_model = HumanInputFormRecipient( form_id=form_id, delivery_id=delivery_id, recipient_type=RecipientType.WEBAPP, @@ -129,11 +122,20 @@ class HumanInputFormRepositoryImpl: elif isinstance(delivery_method, EmailDeliveryMethod): email_recipients_config = delivery_method.config.recipients if email_recipients_config.whole_workspace: - recipients.extend(self._create_whole_workspace_recipients(form_id=form_id, delivery_id=delivery_id)) + recipients.extend( + self._create_whole_workspace_recipients( + session=session, + form_id=form_id, + delivery_id=delivery_id, + ) + ) else: recipients.extend( self._create_email_recipients( - form_id=form_id, delivery_id=delivery_id, recipients=email_recipients_config.items + session=session, + form_id=form_id, + delivery_id=delivery_id, + recipients=email_recipients_config.items, ) ) @@ -141,27 +143,34 @@ class HumanInputFormRepositoryImpl: def _create_email_recipients( self, + session: Session, form_id: str, delivery_id: str, recipients: Sequence[EmailRecipient], - ) -> list[HumanInputRecipient]: - recipient_models: list[HumanInputRecipient] = [] + ) -> list[HumanInputFormRecipient]: + recipient_models: list[HumanInputFormRecipient] = [] member_user_ids: list[str] = [] for r in recipients: if isinstance(r, MemberRecipient): member_user_ids.append(r.user_id) elif isinstance(r, ExternalRecipient): - recipient_model = HumanInputRecipient.new( + recipient_model = HumanInputFormRecipient.new( form_id=form_id, delivery_id=delivery_id, payload=EmailExternalRecipientPayload(email=r.email) ) recipient_models.append(recipient_model) else: raise AssertionError(f"unknown recipient type: recipient={r}") - members = self._member_queirer.get_members_by_ids(member_user_ids) - for member in members: - payload = EmailMemberRecipientPayload(user_id=member.user_id(), email=member.email()) - recipient_model = HumanInputRecipient.new( + member_entries = { + member.user_id: member.email + for member in self._query_workspace_members(session=session, user_ids=member_user_ids) + } + for user_id in member_user_ids: + email = member_entries.get(user_id) + if email is None: + continue + payload = EmailMemberRecipientPayload(user_id=user_id, email=email) + recipient_model = HumanInputFormRecipient.new( form_id=form_id, delivery_id=delivery_id, payload=payload, @@ -169,12 +178,17 @@ class HumanInputFormRepositoryImpl: recipient_models.append(recipient_model) return recipient_models - def _create_whole_workspace_recipients(self, form_id: str, delivery_id: str) -> list[HumanInputRecipient]: + def _create_whole_workspace_recipients( + self, + session: Session, + form_id: str, + delivery_id: str, + ) -> list[HumanInputFormRecipient]: recipeint_models = [] - members = self._member_queirer.get_all_workspace_members() + members = self._query_workspace_members(session=session, user_ids=None) for member in members: - payload = EmailMemberRecipientPayload(user_id=member.user_id(), email=member.email()) - recipient_model = HumanInputRecipient.new( + payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email) + recipient_model = HumanInputFormRecipient.new( form_id=form_id, delivery_id=delivery_id, payload=payload, @@ -183,6 +197,30 @@ class HumanInputFormRepositoryImpl: return recipeint_models + def _query_workspace_members( + self, + session: Session, + user_ids: Sequence[str] | None, + ) -> list[_WorkspaceMemberInfo]: + unique_ids: set[str] | None + if user_ids is None: + unique_ids = None + else: + unique_ids = {user_id for user_id in user_ids if user_id} + if not unique_ids: + return [] + + stmt = ( + select(Account.id, Account.email) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where(TenantAccountJoin.tenant_id == self._tenant_id) + ) + if unique_ids is not None: + stmt = stmt.where(Account.id.in_(unique_ids)) + + rows = session.execute(stmt).all() + return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: form_config: HumanInputNodeData = params.form_config @@ -190,8 +228,12 @@ class HumanInputFormRepositoryImpl: # Generate unique form ID form_id = str(uuidv7()) form_definition = FormDefinition( + form_content=form_config.form_content, inputs=form_config.inputs, user_actions=form_config.user_actions, + rendered_content=params.rendered_content, + timeout=form_config.timeout, + timeout_unit=form_config.timeout_unit, ) form_model = HumanInputForm( id=form_id, @@ -203,9 +245,13 @@ class HumanInputFormRepositoryImpl: expiration_time=form_config.expiration_time(naive_utc_now()), ) session.add(form_model) - web_app_recipient: HumanInputRecipient | None = None + web_app_recipient: HumanInputFormRecipient | None = None for delivery in form_config.delivery_methods: - delivery_and_recipients = self._delivery_method_to_model(form_id, delivery) + delivery_and_recipients = self._delivery_method_to_model( + session=session, + form_id=form_id, + delivery_method=delivery, + ) session.add(delivery_and_recipients.delivery) session.add_all(delivery_and_recipients.recipients) if web_app_recipient is None: @@ -214,7 +260,7 @@ class HumanInputFormRepositoryImpl: return _HumanInputFormEntityImpl(form_model=form_model, web_app_recipient=web_app_recipient) - def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmissionEntity | None: + def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmission | None: query = select(HumanInputForm).where( HumanInputForm.workflow_run_id == workflow_execution_id, HumanInputForm.node_id == node_id, @@ -227,4 +273,13 @@ class HumanInputFormRepositoryImpl: if form_model.submitted_at is None: return None - return _FormSubmissionEntityImpl(form_model=form_model) + return _FormSubmissionImpl(form_model=form_model) + + def get_form_by_token(self, token: str, recipient_type: RecipientType | None = None): + query = ( + select(HumanInputFormRecipient) + .options(selectinload(HumanInputFormRecipient.form)) + .where() + + with self._session_factory(expire_on_commit=False) as session: + form_recipient = session.qu diff --git a/api/core/workflow/repositories/human_input_form_repository.py b/api/core/workflow/repositories/human_input_form_repository.py index d98a4a034c..8fd33086f4 100644 --- a/api/core/workflow/repositories/human_input_form_repository.py +++ b/api/core/workflow/repositories/human_input_form_repository.py @@ -48,8 +48,26 @@ class HumanInputFormEntity(abc.ABC): # webapp delivery? pass + @property + @abc.abstractmethod + def recipients(self) -> list["HumanInputFormRecipientEntity"]: ... -class FormSubmissionEntity(abc.ABC): + +class HumanInputFormRecipientEntity(abc.ABC): + @property + @abc.abstractmethod + def id(self) -> str: + """id returns the identifer of this recipient.""" + ... + + @property + @abc.abstractmethod + def token(self) -> str: + """token returns a random string used to submit form""" + ... + + +class FormSubmission(abc.ABC): @property @abc.abstractmethod def selected_action_id(self) -> str: @@ -81,7 +99,7 @@ class HumanInputFormRepository(Protocol): """ ... - def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmissionEntity | None: + def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmission | None: """Retrieve the submission for a specific human input node. Returns `FormSubmission` if the form has been submitted, or `None` if not. diff --git a/api/models/human_input.py b/api/models/human_input.py index 23e71778b3..e5b2d332cb 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -66,7 +66,7 @@ class HumanInputForm(DefaultFieldsMixin, Base): back_populates="form", lazy="raise", ) - completed_by_recipient: Mapped["HumanInputRecipient | None"] = relationship( + completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship( "HumanInputRecipient", primaryjoin="HumanInputForm.completed_by_recipient_id == HumanInputRecipient.id", lazy="raise", @@ -75,7 +75,7 @@ class HumanInputForm(DefaultFieldsMixin, Base): class HumanInputDelivery(DefaultFieldsMixin, Base): - __tablename__ = "human_input_deliveries" + __tablename__ = "human_input_form_deliveries" form_id: Mapped[str] = mapped_column( StringUUID, @@ -94,10 +94,11 @@ class HumanInputDelivery(DefaultFieldsMixin, Base): back_populates="deliveries", lazy="raise", ) - recipients: Mapped[list["HumanInputRecipient"]] = relationship( + recipients: Mapped[list["HumanInputFormRecipient"]] = relationship( "HumanInputRecipient", + uselist=True, back_populates="delivery", - cascade="all, delete-orphan", + # Require explicit preloading lazy="raise", ) @@ -135,8 +136,8 @@ RecipientPayload = Annotated[ ] -class HumanInputRecipient(DefaultFieldsMixin, Base): - __tablename__ = "human_input_recipients" +class HumanInputFormRecipient(DefaultFieldsMixin, Base): + __tablename__ = "human_input_form_recipients" form_id: Mapped[str] = mapped_column( StringUUID, @@ -158,7 +159,19 @@ class HumanInputRecipient(DefaultFieldsMixin, Base): delivery: Mapped[HumanInputDelivery] = relationship( "HumanInputDelivery", + uselist=False, + foreign_keys=[delivery_id], back_populates="recipients", + # Require explicit preloading + lazy="raise", + ) + + form: Mapped[HumanInputForm] = relationship( + "HumanInputForm", + uselist=False, + foreign_keys=[form_id], + back_populates="recipients", + # Require explicit preloading lazy="raise", ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 6788449680..c7a206cbb7 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -637,14 +637,14 @@ class WorkflowRun(Base): back_populates="workflow_run", ) - @deprecated("This method is retained for historical reasons; avoid using it if possible.") @property + @deprecated("This method is retained for historical reasons; avoid using it if possible.") def created_by_account(self): created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None - @deprecated("This method is retained for historical reasons; avoid using it if possible.") @property + @deprecated("This method is retained for historical reasons; avoid using it if possible.") def created_by_end_user(self): from .model import EndUser @@ -663,8 +663,8 @@ class WorkflowRun(Base): def outputs_dict(self) -> Mapping[str, Any]: return json.loads(self.outputs) if self.outputs else {} - @deprecated("This method is retained for historical reasons; avoid using it if possible.") @property + @deprecated("This method is retained for historical reasons; avoid using it if possible.") def message(self): from .model import Message @@ -672,8 +672,8 @@ class WorkflowRun(Base): db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() ) - @deprecated("This method is retained for historical reasons; avoid using it if possible.") @property + @deprecated("This method is retained for historical reasons; avoid using it if possible.") def workflow(self): return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index aca1058cf9..67c4f25376 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -1,16 +1,22 @@ import abc +import json from collections.abc import Mapping from typing import Any -from sqlalchemy import Engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session, selectinload, sessionmaker from core.workflow.nodes.human_input.entities import FormDefinition +from libs.datetime_utils import naive_utc_now from libs.exception import BaseHTTPException -from models.human_input import RecipientType +from models.account import Account +from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType -class Form(abc.ABC): +class Form: + def __init__(self, form_model: HumanInputForm): + self._form_model = form_model + @abc.abstractmethod def get_definition(self) -> FormDefinition: pass @@ -34,30 +40,102 @@ class FormSubmittedError(HumanInputError, BaseHTTPException): super().__init__(description=description) -class FormNotFoundError(HumanInputError, BaseException): +class FormNotFoundError(HumanInputError, BaseHTTPException): error_code = "human_input_form_not_found" code = 404 +class WebAppDeliveryNotEnabledError(HumanInputError, BaseException): + pass + + class HumanInputService: def __init__( self, - session_factory: sessionmaker | Engine, + session_factory: sessionmaker[Session] | Engine, ): if isinstance(session_factory, Engine): session_factory = sessionmaker(bind=session_factory) self._session_factory = session_factory - def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form: - pass + def get_form_by_token(self, form_token: str) -> Form | None: + query = ( + select(HumanInputFormRecipient) + .options(selectinload(HumanInputFormRecipient.form)) + .where(HumanInputFormRecipient.access_token == form_token) + ) + with self._session_factory(expire_on_commit=False) as session: + recipient = session.scalars(query).first() + if recipient is None: + return None - def get_form_definition_by_id(self, form_id: str) -> Form | None: - pass + return Form(recipient.form) - def submit_form_by_id(self, form_id: str, selected_action_id: str, form_data: Mapping[str, Any]): - pass + def get_form_by_id(self, form_id: str) -> Form | None: + query = select(HumanInputForm).where(HumanInputForm.id == form_id) + with self._session_factory(expire_on_commit=False) as session: + form_model = session.scalars(query).first() + if form_model is None: + return None - def submit_form_by_token( - self, recipient_type: RecipientType, form_token: str, selected_action_id: str, form_data: Mapping[str, Any] - ): - pass + return Form(form_model) + + def submit_form_by_id(self, form_id: str, user: Account, selected_action_id: str, form_data: Mapping[str, Any]): + recipient_query = ( + select(HumanInputFormRecipient) + .options(selectinload(HumanInputFormRecipient.form)) + .where( + HumanInputFormRecipient.recipient_type == RecipientType.WEBAPP, + HumanInputFormRecipient.form_id == form_id, + ) + ) + + with self._session_factory(expire_on_commit=False) as session: + recipient_model = session.scalars(recipient_query).first() + + if recipient_model is None: + raise WebAppDeliveryNotEnabledError() + + form_model = recipient_model.form + form = Form(form_model) + if form.submitted: + raise FormSubmittedError(form_model.id) + + with self._session_factory(expire_on_commit=False) as session, session.begin(): + form_model.selected_action_id = selected_action_id + form_model.submitted_data = json.dumps(form_data) + form_model.submitted_at = naive_utc_now() + form_model.submission_user_id = user.id + + form_model.completed_by_recipient_id = recipient_model.id + session.add(form_model) + # TODO: restart the execution of paused workflow + + def submit_form_by_token(self, form_token: str, selected_action_id: str, form_data: Mapping[str, Any]): + recipient_query = ( + select(HumanInputFormRecipient) + .options(selectinload(HumanInputFormRecipient.form)) + .where( + HumanInputFormRecipient.form_id == form_token, + ) + ) + + with self._session_factory(expire_on_commit=False) as session: + recipient_model = session.scalars(recipient_query).first() + + if recipient_model is None: + raise WebAppDeliveryNotEnabledError() + + form_model = recipient_model.form + form = Form(form_model) + if form.submitted: + raise FormSubmittedError(form_model.id) + + with self._session_factory(expire_on_commit=False) as session, session.begin(): + form_model.selected_action_id = selected_action_id + form_model.submitted_data = json.dumps(form_data) + form_model.submitted_at = naive_utc_now() + form_model.submission_user_id = user.id + + form_model.completed_by_recipient_id = recipient_model.id + session.add(form_model) diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py new file mode 100644 index 0000000000..6731cd0a11 --- /dev/null +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -0,0 +1,158 @@ +"""TestContainers integration tests for HumanInputFormRepositoryImpl.""" + +from __future__ import annotations + +from uuid import uuid4 + +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session + +from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + HumanInputNodeData, + MemberRecipient, + UserAction, +) +from core.workflow.repositories.human_input_form_repository import FormCreateParams +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.human_input import ( + EmailExternalRecipientPayload, + EmailMemberRecipientPayload, + HumanInputFormRecipient, + RecipientType, +) + + +def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]: + tenant = Tenant(name="Test Tenant", status="normal") + session.add(tenant) + session.flush() + + members: list[Account] = [] + for index, email in enumerate(member_emails): + account = Account( + email=email, + name=f"Member {index}", + interface_language="en-US", + status="active", + ) + session.add(account) + session.flush() + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.NORMAL, + current=True, + ) + session.add(tenant_join) + members.append(account) + + session.commit() + return tenant, members + + +def _build_form_params(delivery_methods: list[EmailDeliveryMethod]) -> FormCreateParams: + form_config = HumanInputNodeData( + title="Human Approval", + delivery_methods=delivery_methods, + form_content="

Approve?

", + user_actions=[UserAction(id="approve", title="Approve")], + ) + return FormCreateParams( + workflow_execution_id=str(uuid4()), + node_id="human-input-node", + form_config=form_config, + rendered_content="

Approve?

", + ) + + +def _build_email_delivery( + whole_workspace: bool, recipients: list[MemberRecipient | ExternalRecipient] +) -> EmailDeliveryMethod: + return EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients), + subject="Approval Needed", + body="Please review", + ) + ) + + +class TestHumanInputFormRepositoryImplWithContainers: + def test_create_form_with_whole_workspace_recipients(self, db_session_with_containers: Session) -> None: + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + tenant, members = _create_tenant_with_members( + db_session_with_containers, + member_emails=["member1@example.com", "member2@example.com"], + ) + + repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + params = _build_form_params( + delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], + ) + + form_entity = repository.create_form(params) + + with Session(engine) as verification_session: + recipients = verification_session.scalars( + select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id) + ).all() + + assert len(recipients) == len(members) + member_payloads = [ + EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload) + for recipient in recipients + if recipient.recipient_type == RecipientType.EMAIL_MEMBER + ] + member_emails = {payload.email for payload in member_payloads} + assert member_emails == {member.email for member in members} + + def test_create_form_with_specific_members_and_external(self, db_session_with_containers: Session) -> None: + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + tenant, members = _create_tenant_with_members( + db_session_with_containers, + member_emails=["primary@example.com", "secondary@example.com"], + ) + + repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + params = _build_form_params( + delivery_methods=[ + _build_email_delivery( + whole_workspace=False, + recipients=[ + MemberRecipient(user_id=members[0].id), + ExternalRecipient(email="external@example.com"), + ], + ) + ], + ) + + form_entity = repository.create_form(params) + + with Session(engine) as verification_session: + recipients = verification_session.scalars( + select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id) + ).all() + + member_recipient_payloads = [ + EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload) + for recipient in recipients + if recipient.recipient_type == RecipientType.EMAIL_MEMBER + ] + assert len(member_recipient_payloads) == 1 + assert member_recipient_payloads[0].user_id == members[0].id + + external_payloads = [ + EmailExternalRecipientPayload.model_validate_json(recipient.recipient_payload) + for recipient in recipients + if recipient.recipient_type == RecipientType.EMAIL_EXTERNAL + ] + assert len(external_payloads) == 1 + assert external_payloads[0].email == "external@example.com" diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py new file mode 100644 index 0000000000..649edbb37c --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -0,0 +1,127 @@ +"""Unit tests for HumanInputFormRepositoryImpl private helpers.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.repositories.human_input_reposotiry import ( + HumanInputFormRepositoryImpl, + _WorkspaceMemberInfo, +) +from core.workflow.nodes.human_input.entities import ExternalRecipient, MemberRecipient +from models.human_input import ( + EmailExternalRecipientPayload, + EmailMemberRecipientPayload, + HumanInputFormRecipient, + RecipientType, +) + + +def _build_repository() -> HumanInputFormRepositoryImpl: + return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id") + + +def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]: + created: list[SimpleNamespace] = [] + + def fake_new(cls, form_id: str, delivery_id: str, payload): # type: ignore[no-untyped-def] + recipient = SimpleNamespace( + form_id=form_id, + delivery_id=delivery_id, + recipient_type=payload.TYPE, + recipient_payload=payload.model_dump_json(), + ) + created.append(recipient) + return recipient + + monkeypatch.setattr(HumanInputFormRecipient, "new", classmethod(fake_new)) + return created + + +class TestHumanInputFormRepositoryImplHelpers: + def test_create_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None: + repo = _build_repository() + session_stub = object() + _patch_recipient_factory(monkeypatch) + + def fake_query(self, session, user_ids): # type: ignore[no-untyped-def] + assert session is session_stub + assert user_ids == ["member-1"] + return [_WorkspaceMemberInfo(user_id="member-1", email="member@example.com")] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members", fake_query) + + recipients = repo._create_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients=[ + MemberRecipient(user_id="member-1"), + ExternalRecipient(email="external@example.com"), + ], + ) + + assert len(recipients) == 2 + member_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_MEMBER) + external_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL) + + member_payload = EmailMemberRecipientPayload.model_validate_json(member_recipient.recipient_payload) + assert member_payload.user_id == "member-1" + assert member_payload.email == "member@example.com" + + external_payload = EmailExternalRecipientPayload.model_validate_json(external_recipient.recipient_payload) + assert external_payload.email == "external@example.com" + + def test_create_email_recipients_skips_unknown_members(self, monkeypatch: pytest.MonkeyPatch) -> None: + repo = _build_repository() + session_stub = object() + created = _patch_recipient_factory(monkeypatch) + + def fake_query(self, session, user_ids): # type: ignore[no-untyped-def] + assert session is session_stub + assert user_ids == ["missing-member"] + return [] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members", fake_query) + + recipients = repo._create_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients=[ + MemberRecipient(user_id="missing-member"), + ExternalRecipient(email="external@example.com"), + ], + ) + + assert len(recipients) == 1 + assert recipients[0].recipient_type == RecipientType.EMAIL_EXTERNAL + assert len(created) == 1 # only external recipient created via factory + + def test_create_whole_workspace_recipients_uses_all_members(self, monkeypatch: pytest.MonkeyPatch) -> None: + repo = _build_repository() + session_stub = object() + _patch_recipient_factory(monkeypatch) + + def fake_query(self, session, user_ids): # type: ignore[no-untyped-def] + assert session is session_stub + assert user_ids is None + return [ + _WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"), + _WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"), + ] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members", fake_query) + + recipients = repo._create_whole_workspace_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + ) + + assert len(recipients) == 2 + emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients} + assert emails == {"member1@example.com", "member2@example.com"}