WIP: feat(api): human input service

This commit is contained in:
QuantumGhost 2025-11-18 14:03:42 +08:00
parent c7957d5740
commit c0e15b9e1b
10 changed files with 579 additions and 149 deletions

View File

@ -34,6 +34,7 @@ logger = logging.getLogger(__name__)
class _FormDefinitionWithSite(FormDefinition):
# the site field may be not necessary for console scenario.
site: None

View File

@ -4,55 +4,47 @@ Web App Human Input Form APIs.
import logging
from flask import jsonify
from flask_restful import reqparse
from flask import Response
from flask_restx import reqparse
from controllers.web import api
from controllers.web.error import (
NotFoundError,
)
from controllers.web import web_ns
from controllers.web.error import NotFoundError
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from models.human_input import HumanInputSubmissionType
from models.human_input import RecipientType
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
class HumanInputFormApi(WebApiResource):
"""API for getting human input form definition."""
def _jsonify_form_definition(form: Form) -> Response:
"""Return the Pydantic definition as a JSON response."""
return Response(form.get_definition().model_dump_json(), mimetype="application/json")
def get(self, web_app_form_token: str):
@web_ns.route("/form/human_input/<string:web_app_form_token>")
class HumanInputFormApi(WebApiResource):
"""API for getting and submitting human input forms via the web app."""
def get(self, _app_model: App, _end_user: EndUser, web_app_form_token: str):
"""
Get human input form definition by token.
GET /api/form/human_input/<web_app_form_token>
"""
service = HumanInputService(db.engine)
try:
service = HumanInputFormService(db.session())
form_definition = service.get_form_definition(
identifier=web_app_form_token, is_token=True, include_site_info=True
)
return form_definition, 200
except HumanInputFormNotFoundError:
form = service.get_form_definition_by_token(RecipientType.WEBAPP, web_app_form_token)
except FormNotFoundError:
raise NotFoundError("Form not found")
except HumanInputFormExpiredError:
return jsonify(
{"error_code": "human_input_form_expired", "description": "Human input form has expired"}
), 400
except HumanInputFormAlreadySubmittedError:
return jsonify(
{
"error_code": "human_input_form_submitted",
"description": "Human input form has already been submitted",
}
), 400
if form is None:
raise NotFoundError("Form not found")
class HumanInputFormSubmissionApi(WebApiResource):
"""API for submitting human input forms."""
return _jsonify_form_definition(form)
def post(self, web_app_form_token: str):
def post(self, _app_model: App, _end_user: EndUser, web_app_form_token: str):
"""
Submit human input form by token.
@ -71,36 +63,15 @@ class HumanInputFormSubmissionApi(WebApiResource):
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
service = HumanInputService(db.engine)
try:
# Submit the form
service = HumanInputFormService(db.session())
service.submit_form(
identifier=web_app_form_token,
service.submit_form_by_token(
recipient_type=RecipientType.WEBAPP,
form_token=web_app_form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
action=args["action"],
is_token=True,
submission_type=HumanInputSubmissionType.web_app,
)
return {}, 200
except HumanInputFormNotFoundError:
except FormNotFoundError:
raise NotFoundError("Form not found")
except HumanInputFormExpiredError:
return jsonify(
{"error_code": "human_input_form_expired", "description": "Human input form has expired"}
), 400
except HumanInputFormAlreadySubmittedError:
return jsonify(
{
"error_code": "human_input_form_submitted",
"description": "Human input form has already been submitted",
}
), 400
except InvalidFormDataError as e:
return jsonify({"error_code": "invalid_form_data", "description": e.message}), 400
# Register the APIs
api.add_resource(HumanInputFormApi, "/form/human_input/<string:web_app_form_token>")
api.add_resource(HumanInputFormSubmissionApi, "/form/human_input/<string:web_app_form_token>", methods=["POST"])
return {}, 200

View File

@ -1,19 +1,28 @@
"""
Repository implementations for data access.
"""Repository implementations for data access."""
This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
from __future__ import annotations
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from importlib import import_module
from typing import Any
__all__ = [
"CeleryWorkflowExecutionRepository",
"CeleryWorkflowNodeExecutionRepository",
"DifyCoreRepositoryFactory",
"RepositoryImportError",
"SQLAlchemyWorkflowNodeExecutionRepository",
]
_ATTRIBUTE_MODULE_MAP = {
"CeleryWorkflowExecutionRepository": "core.repositories.celery_workflow_execution_repository",
"CeleryWorkflowNodeExecutionRepository": "core.repositories.celery_workflow_node_execution_repository",
"DifyCoreRepositoryFactory": "core.repositories.factory",
"RepositoryImportError": "core.repositories.factory",
"SQLAlchemyWorkflowNodeExecutionRepository": "core.repositories.sqlalchemy_workflow_node_execution_repository",
}
__all__ = list(_ATTRIBUTE_MODULE_MAP.keys())
def __getattr__(name: str) -> Any:
module_path = _ATTRIBUTE_MODULE_MAP.get(name)
if module_path is None:
raise AttributeError(f"module 'core.repositories' has no attribute '{name}'")
module = import_module(module_path)
return getattr(module, name)
def __dir__() -> list[str]: # pragma: no cover - simple helper
return sorted(__all__)

View File

@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import Any
from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker, selectinload
from core.workflow.nodes.human_input.entities import (
DeliveryChannelConfig,
@ -19,17 +19,18 @@ from core.workflow.nodes.human_input.entities import (
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
FormNotFoundError,
FormSubmissionEntity,
FormSubmission,
HumanInputFormEntity,
)
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models.account import Account, TenantAccountJoin
from models.human_input import (
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputRecipient,
HumanInputFormRecipient,
RecipientType,
WebAppRecipientPayload,
)
@ -38,14 +39,20 @@ from models.human_input import (
@dataclasses.dataclass(frozen=True)
class _DeliveryAndRecipients:
delivery: HumanInputDelivery
recipients: Sequence[HumanInputRecipient]
recipients: Sequence[HumanInputFormRecipient]
def webapp_recipient(self) -> HumanInputRecipient | None:
def webapp_recipient(self) -> HumanInputFormRecipient | None:
return next((i for i in self.recipients if i.recipient_type == RecipientType.WEBAPP), None)
@dataclasses.dataclass(frozen=True)
class _WorkspaceMemberInfo:
user_id: str
email: str
class _HumanInputFormEntityImpl(HumanInputFormEntity):
def __init__(self, form_model: HumanInputForm, web_app_recipient: HumanInputRecipient | None):
def __init__(self, form_model: HumanInputForm, web_app_recipient: HumanInputFormRecipient | None):
self._form_model = form_model
self._web_app_recipient = web_app_recipient
@ -60,7 +67,7 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
return self._web_app_recipient.access_token
class _FormSubmissionEntityImpl(FormSubmissionEntity):
class _FormSubmissionImpl(FormSubmission):
def __init__(self, form_model: HumanInputForm):
self._form_model = form_model
@ -78,37 +85,23 @@ class _FormSubmissionEntityImpl(FormSubmissionEntity):
return json.loads(submitted_data)
class WorkspaceMember:
def user_id(self) -> str:
pass
def email(self) -> str:
pass
class WorkspaceMemberQueirer:
def get_all_workspace_members(self) -> Sequence[WorkspaceMember]:
# TOOD: need a way to query all members in the current workspace.
pass
def get_members_by_ids(self, user_ids: Sequence[str]) -> Sequence[WorkspaceMember]:
pass
class HumanInputFormRepositoryImpl:
def __init__(
self,
session_factory: sessionmaker | Engine,
tenant_id: str,
member_quierer: WorkspaceMemberQueirer,
):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
self._tenant_id = tenant_id
self._member_queirer = member_quierer
def _delivery_method_to_model(self, form_id, delivery_method: DeliveryChannelConfig) -> _DeliveryAndRecipients:
def _delivery_method_to_model(
self,
session: Session,
form_id: str,
delivery_method: DeliveryChannelConfig,
) -> _DeliveryAndRecipients:
delivery_id = str(uuidv7())
delivery_model = HumanInputDelivery(
id=delivery_id,
@ -117,9 +110,9 @@ class HumanInputFormRepositoryImpl:
delivery_config_id=delivery_method.id,
channel_payload=delivery_method.model_dump_json(),
)
recipients: list[HumanInputRecipient] = []
recipients: list[HumanInputFormRecipient] = []
if isinstance(delivery_method, WebAppDeliveryMethod):
recipient_model = HumanInputRecipient(
recipient_model = HumanInputFormRecipient(
form_id=form_id,
delivery_id=delivery_id,
recipient_type=RecipientType.WEBAPP,
@ -129,11 +122,20 @@ class HumanInputFormRepositoryImpl:
elif isinstance(delivery_method, EmailDeliveryMethod):
email_recipients_config = delivery_method.config.recipients
if email_recipients_config.whole_workspace:
recipients.extend(self._create_whole_workspace_recipients(form_id=form_id, delivery_id=delivery_id))
recipients.extend(
self._create_whole_workspace_recipients(
session=session,
form_id=form_id,
delivery_id=delivery_id,
)
)
else:
recipients.extend(
self._create_email_recipients(
form_id=form_id, delivery_id=delivery_id, recipients=email_recipients_config.items
session=session,
form_id=form_id,
delivery_id=delivery_id,
recipients=email_recipients_config.items,
)
)
@ -141,27 +143,34 @@ class HumanInputFormRepositoryImpl:
def _create_email_recipients(
self,
session: Session,
form_id: str,
delivery_id: str,
recipients: Sequence[EmailRecipient],
) -> list[HumanInputRecipient]:
recipient_models: list[HumanInputRecipient] = []
) -> list[HumanInputFormRecipient]:
recipient_models: list[HumanInputFormRecipient] = []
member_user_ids: list[str] = []
for r in recipients:
if isinstance(r, MemberRecipient):
member_user_ids.append(r.user_id)
elif isinstance(r, ExternalRecipient):
recipient_model = HumanInputRecipient.new(
recipient_model = HumanInputFormRecipient.new(
form_id=form_id, delivery_id=delivery_id, payload=EmailExternalRecipientPayload(email=r.email)
)
recipient_models.append(recipient_model)
else:
raise AssertionError(f"unknown recipient type: recipient={r}")
members = self._member_queirer.get_members_by_ids(member_user_ids)
for member in members:
payload = EmailMemberRecipientPayload(user_id=member.user_id(), email=member.email())
recipient_model = HumanInputRecipient.new(
member_entries = {
member.user_id: member.email
for member in self._query_workspace_members(session=session, user_ids=member_user_ids)
}
for user_id in member_user_ids:
email = member_entries.get(user_id)
if email is None:
continue
payload = EmailMemberRecipientPayload(user_id=user_id, email=email)
recipient_model = HumanInputFormRecipient.new(
form_id=form_id,
delivery_id=delivery_id,
payload=payload,
@ -169,12 +178,17 @@ class HumanInputFormRepositoryImpl:
recipient_models.append(recipient_model)
return recipient_models
def _create_whole_workspace_recipients(self, form_id: str, delivery_id: str) -> list[HumanInputRecipient]:
def _create_whole_workspace_recipients(
self,
session: Session,
form_id: str,
delivery_id: str,
) -> list[HumanInputFormRecipient]:
recipeint_models = []
members = self._member_queirer.get_all_workspace_members()
members = self._query_workspace_members(session=session, user_ids=None)
for member in members:
payload = EmailMemberRecipientPayload(user_id=member.user_id(), email=member.email())
recipient_model = HumanInputRecipient.new(
payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email)
recipient_model = HumanInputFormRecipient.new(
form_id=form_id,
delivery_id=delivery_id,
payload=payload,
@ -183,6 +197,30 @@ class HumanInputFormRepositoryImpl:
return recipeint_models
def _query_workspace_members(
self,
session: Session,
user_ids: Sequence[str] | None,
) -> list[_WorkspaceMemberInfo]:
unique_ids: set[str] | None
if user_ids is None:
unique_ids = None
else:
unique_ids = {user_id for user_id in user_ids if user_id}
if not unique_ids:
return []
stmt = (
select(Account.id, Account.email)
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
.where(TenantAccountJoin.tenant_id == self._tenant_id)
)
if unique_ids is not None:
stmt = stmt.where(Account.id.in_(unique_ids))
rows = session.execute(stmt).all()
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
form_config: HumanInputNodeData = params.form_config
@ -190,8 +228,12 @@ class HumanInputFormRepositoryImpl:
# Generate unique form ID
form_id = str(uuidv7())
form_definition = FormDefinition(
form_content=form_config.form_content,
inputs=form_config.inputs,
user_actions=form_config.user_actions,
rendered_content=params.rendered_content,
timeout=form_config.timeout,
timeout_unit=form_config.timeout_unit,
)
form_model = HumanInputForm(
id=form_id,
@ -203,9 +245,13 @@ class HumanInputFormRepositoryImpl:
expiration_time=form_config.expiration_time(naive_utc_now()),
)
session.add(form_model)
web_app_recipient: HumanInputRecipient | None = None
web_app_recipient: HumanInputFormRecipient | None = None
for delivery in form_config.delivery_methods:
delivery_and_recipients = self._delivery_method_to_model(form_id, delivery)
delivery_and_recipients = self._delivery_method_to_model(
session=session,
form_id=form_id,
delivery_method=delivery,
)
session.add(delivery_and_recipients.delivery)
session.add_all(delivery_and_recipients.recipients)
if web_app_recipient is None:
@ -214,7 +260,7 @@ class HumanInputFormRepositoryImpl:
return _HumanInputFormEntityImpl(form_model=form_model, web_app_recipient=web_app_recipient)
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmissionEntity | None:
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmission | None:
query = select(HumanInputForm).where(
HumanInputForm.workflow_run_id == workflow_execution_id,
HumanInputForm.node_id == node_id,
@ -227,4 +273,13 @@ class HumanInputFormRepositoryImpl:
if form_model.submitted_at is None:
return None
return _FormSubmissionEntityImpl(form_model=form_model)
return _FormSubmissionImpl(form_model=form_model)
def get_form_by_token(self, token: str, recipient_type: RecipientType | None = None):
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where()
with self._session_factory(expire_on_commit=False) as session:
form_recipient = session.qu

View File

@ -48,8 +48,26 @@ class HumanInputFormEntity(abc.ABC):
# webapp delivery?
pass
@property
@abc.abstractmethod
def recipients(self) -> list["HumanInputFormRecipientEntity"]: ...
class FormSubmissionEntity(abc.ABC):
class HumanInputFormRecipientEntity(abc.ABC):
@property
@abc.abstractmethod
def id(self) -> str:
"""id returns the identifer of this recipient."""
...
@property
@abc.abstractmethod
def token(self) -> str:
"""token returns a random string used to submit form"""
...
class FormSubmission(abc.ABC):
@property
@abc.abstractmethod
def selected_action_id(self) -> str:
@ -81,7 +99,7 @@ class HumanInputFormRepository(Protocol):
"""
...
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmissionEntity | None:
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmission | None:
"""Retrieve the submission for a specific human input node.
Returns `FormSubmission` if the form has been submitted, or `None` if not.

View File

@ -66,7 +66,7 @@ class HumanInputForm(DefaultFieldsMixin, Base):
back_populates="form",
lazy="raise",
)
completed_by_recipient: Mapped["HumanInputRecipient | None"] = relationship(
completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship(
"HumanInputRecipient",
primaryjoin="HumanInputForm.completed_by_recipient_id == HumanInputRecipient.id",
lazy="raise",
@ -75,7 +75,7 @@ class HumanInputForm(DefaultFieldsMixin, Base):
class HumanInputDelivery(DefaultFieldsMixin, Base):
__tablename__ = "human_input_deliveries"
__tablename__ = "human_input_form_deliveries"
form_id: Mapped[str] = mapped_column(
StringUUID,
@ -94,10 +94,11 @@ class HumanInputDelivery(DefaultFieldsMixin, Base):
back_populates="deliveries",
lazy="raise",
)
recipients: Mapped[list["HumanInputRecipient"]] = relationship(
recipients: Mapped[list["HumanInputFormRecipient"]] = relationship(
"HumanInputRecipient",
uselist=True,
back_populates="delivery",
cascade="all, delete-orphan",
# Require explicit preloading
lazy="raise",
)
@ -135,8 +136,8 @@ RecipientPayload = Annotated[
]
class HumanInputRecipient(DefaultFieldsMixin, Base):
__tablename__ = "human_input_recipients"
class HumanInputFormRecipient(DefaultFieldsMixin, Base):
__tablename__ = "human_input_form_recipients"
form_id: Mapped[str] = mapped_column(
StringUUID,
@ -158,7 +159,19 @@ class HumanInputRecipient(DefaultFieldsMixin, Base):
delivery: Mapped[HumanInputDelivery] = relationship(
"HumanInputDelivery",
uselist=False,
foreign_keys=[delivery_id],
back_populates="recipients",
# Require explicit preloading
lazy="raise",
)
form: Mapped[HumanInputForm] = relationship(
"HumanInputForm",
uselist=False,
foreign_keys=[form_id],
back_populates="recipients",
# Require explicit preloading
lazy="raise",
)

View File

@ -637,14 +637,14 @@ class WorkflowRun(Base):
back_populates="workflow_run",
)
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def created_by_account(self):
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def created_by_end_user(self):
from .model import EndUser
@ -663,8 +663,8 @@ class WorkflowRun(Base):
def outputs_dict(self) -> Mapping[str, Any]:
return json.loads(self.outputs) if self.outputs else {}
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def message(self):
from .model import Message
@ -672,8 +672,8 @@ class WorkflowRun(Base):
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
)
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def workflow(self):
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()

View File

@ -1,16 +1,22 @@
import abc
import json
from collections.abc import Mapping
from typing import Any
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.nodes.human_input.entities import FormDefinition
from libs.datetime_utils import naive_utc_now
from libs.exception import BaseHTTPException
from models.human_input import RecipientType
from models.account import Account
from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType
class Form(abc.ABC):
class Form:
def __init__(self, form_model: HumanInputForm):
self._form_model = form_model
@abc.abstractmethod
def get_definition(self) -> FormDefinition:
pass
@ -34,30 +40,102 @@ class FormSubmittedError(HumanInputError, BaseHTTPException):
super().__init__(description=description)
class FormNotFoundError(HumanInputError, BaseException):
class FormNotFoundError(HumanInputError, BaseHTTPException):
error_code = "human_input_form_not_found"
code = 404
class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
pass
class HumanInputService:
def __init__(
self,
session_factory: sessionmaker | Engine,
session_factory: sessionmaker[Session] | Engine,
):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form:
pass
def get_form_by_token(self, form_token: str) -> Form | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(HumanInputFormRecipient.access_token == form_token)
)
with self._session_factory(expire_on_commit=False) as session:
recipient = session.scalars(query).first()
if recipient is None:
return None
def get_form_definition_by_id(self, form_id: str) -> Form | None:
pass
return Form(recipient.form)
def submit_form_by_id(self, form_id: str, selected_action_id: str, form_data: Mapping[str, Any]):
pass
def get_form_by_id(self, form_id: str) -> Form | None:
query = select(HumanInputForm).where(HumanInputForm.id == form_id)
with self._session_factory(expire_on_commit=False) as session:
form_model = session.scalars(query).first()
if form_model is None:
return None
def submit_form_by_token(
self, recipient_type: RecipientType, form_token: str, selected_action_id: str, form_data: Mapping[str, Any]
):
pass
return Form(form_model)
def submit_form_by_id(self, form_id: str, user: Account, selected_action_id: str, form_data: Mapping[str, Any]):
recipient_query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(
HumanInputFormRecipient.recipient_type == RecipientType.WEBAPP,
HumanInputFormRecipient.form_id == form_id,
)
)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(recipient_query).first()
if recipient_model is None:
raise WebAppDeliveryNotEnabledError()
form_model = recipient_model.form
form = Form(form_model)
if form.submitted:
raise FormSubmittedError(form_model.id)
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model.selected_action_id = selected_action_id
form_model.submitted_data = json.dumps(form_data)
form_model.submitted_at = naive_utc_now()
form_model.submission_user_id = user.id
form_model.completed_by_recipient_id = recipient_model.id
session.add(form_model)
# TODO: restart the execution of paused workflow
def submit_form_by_token(self, form_token: str, selected_action_id: str, form_data: Mapping[str, Any]):
recipient_query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(
HumanInputFormRecipient.form_id == form_token,
)
)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(recipient_query).first()
if recipient_model is None:
raise WebAppDeliveryNotEnabledError()
form_model = recipient_model.form
form = Form(form_model)
if form.submitted:
raise FormSubmittedError(form_model.id)
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model.selected_action_id = selected_action_id
form_model.submitted_data = json.dumps(form_data)
form_model.submitted_at = naive_utc_now()
form_model.submission_user_id = user.id
form_model.completed_by_recipient_id = recipient_model.id
session.add(form_model)

View File

@ -0,0 +1,158 @@
"""TestContainers integration tests for HumanInputFormRepositoryImpl."""
from __future__ import annotations
from uuid import uuid4
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session
from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
HumanInputNodeData,
MemberRecipient,
UserAction,
)
from core.workflow.repositories.human_input_form_repository import FormCreateParams
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.human_input import (
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
HumanInputFormRecipient,
RecipientType,
)
def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]:
tenant = Tenant(name="Test Tenant", status="normal")
session.add(tenant)
session.flush()
members: list[Account] = []
for index, email in enumerate(member_emails):
account = Account(
email=email,
name=f"Member {index}",
interface_language="en-US",
status="active",
)
session.add(account)
session.flush()
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.NORMAL,
current=True,
)
session.add(tenant_join)
members.append(account)
session.commit()
return tenant, members
def _build_form_params(delivery_methods: list[EmailDeliveryMethod]) -> FormCreateParams:
form_config = HumanInputNodeData(
title="Human Approval",
delivery_methods=delivery_methods,
form_content="<p>Approve?</p>",
user_actions=[UserAction(id="approve", title="Approve")],
)
return FormCreateParams(
workflow_execution_id=str(uuid4()),
node_id="human-input-node",
form_config=form_config,
rendered_content="<p>Approve?</p>",
)
def _build_email_delivery(
whole_workspace: bool, recipients: list[MemberRecipient | ExternalRecipient]
) -> EmailDeliveryMethod:
return EmailDeliveryMethod(
config=EmailDeliveryConfig(
recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients),
subject="Approval Needed",
body="Please review",
)
)
class TestHumanInputFormRepositoryImplWithContainers:
def test_create_form_with_whole_workspace_recipients(self, db_session_with_containers: Session) -> None:
engine = db_session_with_containers.get_bind()
assert isinstance(engine, Engine)
tenant, members = _create_tenant_with_members(
db_session_with_containers,
member_emails=["member1@example.com", "member2@example.com"],
)
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
params = _build_form_params(
delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])],
)
form_entity = repository.create_form(params)
with Session(engine) as verification_session:
recipients = verification_session.scalars(
select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id)
).all()
assert len(recipients) == len(members)
member_payloads = [
EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload)
for recipient in recipients
if recipient.recipient_type == RecipientType.EMAIL_MEMBER
]
member_emails = {payload.email for payload in member_payloads}
assert member_emails == {member.email for member in members}
def test_create_form_with_specific_members_and_external(self, db_session_with_containers: Session) -> None:
engine = db_session_with_containers.get_bind()
assert isinstance(engine, Engine)
tenant, members = _create_tenant_with_members(
db_session_with_containers,
member_emails=["primary@example.com", "secondary@example.com"],
)
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
params = _build_form_params(
delivery_methods=[
_build_email_delivery(
whole_workspace=False,
recipients=[
MemberRecipient(user_id=members[0].id),
ExternalRecipient(email="external@example.com"),
],
)
],
)
form_entity = repository.create_form(params)
with Session(engine) as verification_session:
recipients = verification_session.scalars(
select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id)
).all()
member_recipient_payloads = [
EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload)
for recipient in recipients
if recipient.recipient_type == RecipientType.EMAIL_MEMBER
]
assert len(member_recipient_payloads) == 1
assert member_recipient_payloads[0].user_id == members[0].id
external_payloads = [
EmailExternalRecipientPayload.model_validate_json(recipient.recipient_payload)
for recipient in recipients
if recipient.recipient_type == RecipientType.EMAIL_EXTERNAL
]
assert len(external_payloads) == 1
assert external_payloads[0].email == "external@example.com"

View File

@ -0,0 +1,127 @@
"""Unit tests for HumanInputFormRepositoryImpl private helpers."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_reposotiry import (
HumanInputFormRepositoryImpl,
_WorkspaceMemberInfo,
)
from core.workflow.nodes.human_input.entities import ExternalRecipient, MemberRecipient
from models.human_input import (
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
HumanInputFormRecipient,
RecipientType,
)
def _build_repository() -> HumanInputFormRepositoryImpl:
return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id")
def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]:
created: list[SimpleNamespace] = []
def fake_new(cls, form_id: str, delivery_id: str, payload): # type: ignore[no-untyped-def]
recipient = SimpleNamespace(
form_id=form_id,
delivery_id=delivery_id,
recipient_type=payload.TYPE,
recipient_payload=payload.model_dump_json(),
)
created.append(recipient)
return recipient
monkeypatch.setattr(HumanInputFormRecipient, "new", classmethod(fake_new))
return created
class TestHumanInputFormRepositoryImplHelpers:
def test_create_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None:
repo = _build_repository()
session_stub = object()
_patch_recipient_factory(monkeypatch)
def fake_query(self, session, user_ids): # type: ignore[no-untyped-def]
assert session is session_stub
assert user_ids == ["member-1"]
return [_WorkspaceMemberInfo(user_id="member-1", email="member@example.com")]
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members", fake_query)
recipients = repo._create_email_recipients(
session=session_stub,
form_id="form-id",
delivery_id="delivery-id",
recipients=[
MemberRecipient(user_id="member-1"),
ExternalRecipient(email="external@example.com"),
],
)
assert len(recipients) == 2
member_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_MEMBER)
external_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL)
member_payload = EmailMemberRecipientPayload.model_validate_json(member_recipient.recipient_payload)
assert member_payload.user_id == "member-1"
assert member_payload.email == "member@example.com"
external_payload = EmailExternalRecipientPayload.model_validate_json(external_recipient.recipient_payload)
assert external_payload.email == "external@example.com"
def test_create_email_recipients_skips_unknown_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
repo = _build_repository()
session_stub = object()
created = _patch_recipient_factory(monkeypatch)
def fake_query(self, session, user_ids): # type: ignore[no-untyped-def]
assert session is session_stub
assert user_ids == ["missing-member"]
return []
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members", fake_query)
recipients = repo._create_email_recipients(
session=session_stub,
form_id="form-id",
delivery_id="delivery-id",
recipients=[
MemberRecipient(user_id="missing-member"),
ExternalRecipient(email="external@example.com"),
],
)
assert len(recipients) == 1
assert recipients[0].recipient_type == RecipientType.EMAIL_EXTERNAL
assert len(created) == 1 # only external recipient created via factory
def test_create_whole_workspace_recipients_uses_all_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
repo = _build_repository()
session_stub = object()
_patch_recipient_factory(monkeypatch)
def fake_query(self, session, user_ids): # type: ignore[no-untyped-def]
assert session is session_stub
assert user_ids is None
return [
_WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"),
_WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"),
]
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members", fake_query)
recipients = repo._create_whole_workspace_recipients(
session=session_stub,
form_id="form-id",
delivery_id="delivery-id",
)
assert len(recipients) == 2
emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients}
assert emails == {"member1@example.com", "member2@example.com"}