dify/api/services/human_input_delivery_test_service.py
2026-01-06 08:54:06 +08:00

219 lines
7.0 KiB
Python

from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Protocol
from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from core.workflow.nodes.human_input.entities import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
ExternalRecipient,
MemberRecipient,
)
from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_template_renderer import render_email_template
from models import Account, TenantAccountJoin
class DeliveryTestStatus(StrEnum):
OK = "ok"
FAILED = "failed"
@dataclass(frozen=True)
class DeliveryTestContext:
    """Immutable inputs describing the human-input node a test delivery is sent for."""

    # Workspace the test runs in.
    tenant_id: str
    # App and node under test.
    app_id: str
    node_id: str
    # Optional display title of the node; may be None.
    node_title: str | None
    # Form content already rendered for the message body.
    rendered_content: str
    # Extra template substitutions supplied by the caller (empty by default;
    # a fresh dict per instance).
    template_vars: dict[str, str] = field(default_factory=dict)
@dataclass(frozen=True)
class DeliveryTestResult:
    """Immutable summary returned after a delivery test send."""

    # Overall outcome of the test send.
    status: DeliveryTestStatus
    # Addresses/targets a test message was sent to (fresh list per instance).
    delivered_to: list[str] = field(default_factory=list)
    # Non-fatal messages to surface to the caller (fresh list per instance).
    warnings: list[str] = field(default_factory=list)
class DeliveryTestError(Exception):
    """Base error raised when a delivery test send cannot be performed."""
class DeliveryTestUnsupportedError(DeliveryTestError):
    """Raised when no registered handler can test-send the given delivery method."""
class DeliveryTestHandler(Protocol):
    """Structural interface for objects that perform a test send for a kind of delivery channel."""

    def supports(self, method: DeliveryChannelConfig) -> bool:
        """Return True if this handler can test-send *method*."""

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Perform a test send of *method* for *context* and report the outcome."""
class DeliveryTestRegistry:
    """Ordered set of delivery test handlers; each request goes to the first handler that supports it."""

    def __init__(self, handlers: list[DeliveryTestHandler] | None = None) -> None:
        # Take a private copy so register() never mutates the caller's list.
        self._handlers = [] if not handlers else list(handlers)

    def register(self, handler: DeliveryTestHandler) -> None:
        """Append *handler*; it is consulted after every previously registered handler."""
        self._handlers.append(handler)

    def dispatch(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Run the test send on the first handler whose ``supports(method)`` is true.

        Handlers are probed lazily in registration order; probing stops at the
        first match.

        Raises:
            DeliveryTestUnsupportedError: no registered handler supports *method*.
        """
        chosen = next((candidate for candidate in self._handlers if candidate.supports(method)), None)
        if chosen is None:
            raise DeliveryTestUnsupportedError("Delivery method does not support test send.")
        return chosen.send_test(context=context, method=method)

    @classmethod
    def default(cls) -> DeliveryTestRegistry:
        """Build a registry containing only the email delivery handler."""
        return cls(handlers=[EmailDeliveryTestHandler()])
class HumanInputDeliveryTestService:
    """Entry point for test-sending a human-input delivery method via a handler registry."""

    def __init__(self, registry: DeliveryTestRegistry | None = None) -> None:
        # Fall back to the default registry when none (or a falsy one) is given.
        self._registry = registry if registry else DeliveryTestRegistry.default()

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Forward the test send to the registry and return its result unchanged."""
        registry = self._registry
        return registry.dispatch(context=context, method=method)
class EmailDeliveryTestHandler:
    """Test-send handler for ``EmailDeliveryMethod`` delivery channels.

    Resolves the configured recipients (external addresses, selected workspace
    members, or the whole workspace), renders the configured subject and body
    once per recipient, and sends each message through the ``mail`` extension.
    Form-specific placeholders (``form_id``, ``workflow_run_id``, ``form_token``,
    ``form_link``) render as empty strings unless ``context.template_vars``
    supplies them.
    """

    def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None:
        """Create the handler.

        Args:
            session_factory: Factory for the sessions used to look up workspace
                member emails. A ``sessionmaker`` is used as-is, an ``Engine`` is
                wrapped in a new ``sessionmaker``, and ``None`` falls back to a
                ``sessionmaker`` bound to ``db.engine``.
        """
        if session_factory is None:
            session_factory = sessionmaker(bind=db.engine)
        elif isinstance(session_factory, Engine):
            session_factory = sessionmaker(bind=session_factory)
        self._session_factory = session_factory

    def supports(self, method: DeliveryChannelConfig) -> bool:
        """Return True when *method* is an ``EmailDeliveryMethod``."""
        return isinstance(method, EmailDeliveryMethod)

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Send one test email per resolved recipient.

        Args:
            context: Tenant and node information plus the rendered form content.
            method: Delivery channel to test; must be an ``EmailDeliveryMethod``.

        Returns:
            An ``OK`` result whose ``delivered_to`` lists every address a message
            was sent to, in send order.

        Raises:
            DeliveryTestUnsupportedError: *method* is not an email delivery method.
            DeliveryTestError: the mail client is not initialized, or no recipient
                email could be resolved.

        Any exception raised by ``mail.send`` propagates unchanged; recipients
        earlier in the list will already have been sent their message.
        """
        if not isinstance(method, EmailDeliveryMethod):
            raise DeliveryTestUnsupportedError("Delivery method does not support test send.")
        if not mail.is_inited():
            raise DeliveryTestError("Mail client is not initialized.")

        recipients = self._resolve_recipients(tenant_id=context.tenant_id, method=method)
        if not recipients:
            raise DeliveryTestError("No recipients configured for delivery method.")

        delivered: list[str] = []
        for recipient_email in recipients:
            subject, body = self._render_email(
                context=context,
                method=method,
                recipient_email=recipient_email,
            )
            mail.send(
                to=recipient_email,
                subject=subject,
                html=body,
            )
            delivered.append(recipient_email)
        return DeliveryTestResult(status=DeliveryTestStatus.OK, delivered_to=delivered)

    def _render_email(
        self,
        *,
        context: DeliveryTestContext,
        method: EmailDeliveryMethod,
        recipient_email: str,
    ) -> tuple[str, str]:
        """Render the ``(subject, html_body)`` pair for one recipient."""
        substitutions = self._build_substitutions(
            context=context,
            recipient_email=recipient_email,
        )
        subject = render_email_template(method.config.subject, substitutions)
        # The body's URL placeholder is replaced with the form link first (empty
        # for a test send unless template_vars provides one); the result is then
        # rendered with the same substitutions as the subject.
        templated_body = EmailDeliveryConfig.replace_url_placeholder(
            method.config.body,
            substitutions.get("form_link"),
        )
        body = render_email_template(templated_body, substitutions)
        return subject, body

    def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]:
        """Return the unique recipient addresses configured on *method*.

        External addresses come first, in configuration order, followed by
        member addresses. When ``whole_workspace`` is set, every member of
        *tenant_id* is included and individually selected members are not
        looked up separately. Empty addresses are dropped, and duplicates are
        removed keeping the first occurrence.
        """
        recipients = method.config.recipients
        emails: list[str] = []
        member_user_ids: list[str] = []
        for recipient in recipients.items:
            if isinstance(recipient, MemberRecipient):
                member_user_ids.append(recipient.user_id)
            elif isinstance(recipient, ExternalRecipient) and recipient.email:
                emails.append(recipient.email)

        if recipients.whole_workspace:
            # Every tenant member is a recipient, so the explicit member list is redundant.
            member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None)
            emails.extend(member_emails.values())
        elif member_user_ids:
            member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids)
            # Keep the configured member order; IDs not joined to the tenant (or
            # without an email) are skipped.
            emails.extend(member_emails[user_id] for user_id in member_user_ids if member_emails.get(user_id))

        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(email for email in emails if email))

    def _query_workspace_member_emails(
        self,
        *,
        tenant_id: str,
        user_ids: list[str] | None,
    ) -> dict[str, str]:
        """Map account ID to email for accounts joined to *tenant_id*.

        Args:
            tenant_id: Workspace whose members are queried.
            user_ids: Restrict the lookup to these account IDs (empty IDs are
                ignored). ``None`` means every member of the workspace.

        Returns:
            ``{account_id: email}``. Returns ``{}`` without querying the
            database when *user_ids* contains no non-empty ID.
        """
        unique_ids: set[str] | None = None
        if user_ids is not None:
            unique_ids = {user_id for user_id in user_ids if user_id}
            if not unique_ids:
                return {}

        stmt = (
            select(Account.id, Account.email)
            .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
            .where(TenantAccountJoin.tenant_id == tenant_id)
        )
        if unique_ids is not None:
            stmt = stmt.where(Account.id.in_(unique_ids))

        with self._session_factory() as session:
            rows = session.execute(stmt).all()
        # Each row is an (id, email) pair.
        return dict(rows)

    @staticmethod
    def _build_substitutions(
        *,
        context: DeliveryTestContext,
        recipient_email: str,
    ) -> dict[str, str]:
        """Build the template substitution map for one recipient.

        Form-specific keys are empty strings because a test send has no backing
        form, and a missing ``node_title`` becomes ``""``. Entries from
        ``context.template_vars`` whose value is not None are applied last, so
        they override the built-in keys.
        """
        raw_values: dict[str, str | None] = {
            "form_id": "",
            "node_title": context.node_title,
            "workflow_run_id": "",
            "form_token": "",
            "form_link": "",
            "form_content": context.rendered_content,
            "recipient_email": recipient_email,
        }
        substitutions = {key: value or "" for key, value in raw_values.items()}
        if context.template_vars:
            substitutions.update({key: value for key, value in context.template_vars.items() if value is not None})
        return substitutions