# Mirror of https://github.com/langgenius/dify.git (synced 2026-01-30 07:32:45 +08:00).
# Source file: 219 lines, 7.0 KiB, Python.
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from enum import StrEnum
|
|
from typing import Protocol
|
|
|
|
from sqlalchemy import Engine, select
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from core.workflow.nodes.human_input.entities import (
|
|
DeliveryChannelConfig,
|
|
EmailDeliveryConfig,
|
|
EmailDeliveryMethod,
|
|
ExternalRecipient,
|
|
MemberRecipient,
|
|
)
|
|
from extensions.ext_database import db
|
|
from extensions.ext_mail import mail
|
|
from libs.email_template_renderer import render_email_template
|
|
from models import Account, TenantAccountJoin
|
|
|
|
|
|
class DeliveryTestStatus(StrEnum):
|
|
OK = "ok"
|
|
FAILED = "failed"
|
|
|
|
|
|
@dataclass(frozen=True)
class DeliveryTestContext:
    """Immutable description of the human-input node a test delivery is sent for.

    Field order is part of the positional constructor signature.
    """

    # Workspace (tenant) that owns the app; used to scope recipient lookups.
    tenant_id: str
    # App containing the workflow node under test.
    app_id: str
    # Identifier of the human-input node under test.
    node_id: str
    # Display title of the node, if it has one.
    node_title: str | None
    # Form content that has already been rendered by the caller.
    rendered_content: str
    # Extra template variables; each instance gets its own fresh dict by default.
    template_vars: dict[str, str] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass(frozen=True)
class DeliveryTestResult:
    """Immutable report returned by a delivery test send."""

    # Overall outcome of the test send.
    status: DeliveryTestStatus
    # Recipients the test message was delivered to (fresh list per instance).
    delivered_to: list[str] = field(default_factory=list)
    # Non-fatal messages to surface to the caller (fresh list per instance).
    warnings: list[str] = field(default_factory=list)
|
|
|
|
|
|
class DeliveryTestError(Exception):
    """Base error for a delivery test send that cannot be carried out."""
|
|
|
|
|
|
class DeliveryTestUnsupportedError(DeliveryTestError):
    """Raised when a delivery method has no handler able to perform a test send."""
|
|
|
|
|
|
class DeliveryTestHandler(Protocol):
    """Structural interface for objects that can test-send one kind of delivery channel."""

    def supports(self, method: DeliveryChannelConfig) -> bool:
        """Return whether this handler can perform a test send for *method*."""
        ...

    def send_test(self, *, context: DeliveryTestContext, method: DeliveryChannelConfig) -> DeliveryTestResult:
        """Perform a test send for *method* described by *context* and report the outcome."""
        ...
|
|
|
|
|
|
class DeliveryTestRegistry:
    """Ordered collection of delivery-test handlers.

    A request goes to the first registered handler that reports support for
    the delivery method.
    """

    def __init__(self, handlers: list[DeliveryTestHandler] | None = None) -> None:
        """Seed the registry with a private copy of *handlers*.

        Args:
            handlers: Initial handlers in priority order. ``None`` or an empty
                collection starts the registry empty.
        """
        self._handlers = [*handlers] if handlers else []

    def register(self, handler: DeliveryTestHandler) -> None:
        """Append *handler*; it is consulted after every handler already present."""
        self._handlers.append(handler)

    def dispatch(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Route the test send to the first handler whose ``supports`` accepts *method*.

        Returns:
            Whatever the selected handler's ``send_test`` returns.

        Raises:
            DeliveryTestUnsupportedError: If no registered handler supports *method*.
        """
        # Lazily probes handlers in order and stops at the first match.
        chosen = next((candidate for candidate in self._handlers if candidate.supports(method)), None)
        if chosen is None:
            raise DeliveryTestUnsupportedError("Delivery method does not support test send.")
        return chosen.send_test(context=context, method=method)

    @classmethod
    def default(cls) -> DeliveryTestRegistry:
        """Build the standard registry, which holds a single ``EmailDeliveryTestHandler``."""
        return cls([EmailDeliveryTestHandler()])
|
|
|
|
|
|
class HumanInputDeliveryTestService:
    """Facade for sending test deliveries of a human-input node.

    Every request is delegated to a ``DeliveryTestRegistry``.
    """

    def __init__(self, registry: DeliveryTestRegistry | None = None) -> None:
        """Bind the service to a handler registry.

        Args:
            registry: Registry to dispatch through. When omitted (``None``), the
                registry from ``DeliveryTestRegistry.default()`` is used.
        """
        # Compare with None explicitly. A caller-supplied registry must be kept
        # even if it happens to be falsy (e.g. a registry type defining __len__),
        # rather than being silently replaced by the default one.
        self._registry = DeliveryTestRegistry.default() if registry is None else registry

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Dispatch a test send for *method* through the configured registry.

        Returns:
            The result produced by the registry's ``dispatch``.

        Raises:
            Any exception raised by the registry's ``dispatch`` propagates unchanged.
        """
        return self._registry.dispatch(context=context, method=method)
|
|
|
|
|
|
class EmailDeliveryTestHandler:
    """Test-send handler for ``EmailDeliveryMethod`` delivery configurations.

    Resolves the configured recipients to e-mail addresses and sends one
    rendered message per address through the ``mail`` extension.
    """

    def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None:
        """Keep a session factory for workspace-member e-mail lookups.

        Args:
            session_factory: An existing ``sessionmaker`` (used as-is), an
                ``Engine`` (wrapped in a new ``sessionmaker``), or ``None`` to
                bind a new ``sessionmaker`` to ``db.engine``.
        """
        if isinstance(session_factory, Engine):
            self._session_factory = sessionmaker(bind=session_factory)
        elif session_factory is None:
            self._session_factory = sessionmaker(bind=db.engine)
        else:
            self._session_factory = session_factory

    def supports(self, method: DeliveryChannelConfig) -> bool:
        """Report whether *method* is an ``EmailDeliveryMethod``."""
        return isinstance(method, EmailDeliveryMethod)

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        """Render and send the configured e-mail to every resolved recipient.

        Args:
            context: Supplies the tenant used for recipient lookup and the
                values used to build template substitutions.
            method: Delivery configuration; must be an ``EmailDeliveryMethod``.

        Returns:
            A result with status ``OK`` listing every address a message was sent to.

        Raises:
            DeliveryTestUnsupportedError: If *method* is not an ``EmailDeliveryMethod``.
            DeliveryTestError: If the mail client is not initialized, or no
                recipient address could be resolved.
        """
        if not isinstance(method, EmailDeliveryMethod):
            raise DeliveryTestUnsupportedError("Delivery method does not support test send.")
        if not mail.is_inited():
            raise DeliveryTestError("Mail client is not initialized.")

        addresses = self._resolve_recipients(tenant_id=context.tenant_id, method=method)
        if not addresses:
            raise DeliveryTestError("No recipients configured for delivery method.")

        sent_to: list[str] = []
        for address in addresses:
            subject, html = self._render_message(context=context, method=method, recipient_email=address)
            # There is no per-recipient error handling. An exception from mail.send
            # propagates, and the addresses already sent to are not reported back.
            mail.send(to=address, subject=subject, html=html)
            sent_to.append(address)
        return DeliveryTestResult(status=DeliveryTestStatus.OK, delivered_to=sent_to)

    def _render_message(
        self,
        *,
        context: DeliveryTestContext,
        method: EmailDeliveryMethod,
        recipient_email: str,
    ) -> tuple[str, str]:
        """Render the ``(subject, html_body)`` pair of the test e-mail for one recipient."""
        values = self._build_substitutions(context=context, recipient_email=recipient_email)
        rendered_subject = render_email_template(method.config.subject, values)
        # The URL placeholder is substituted first, then the body is rendered
        # with the same substitution map.
        body_template = EmailDeliveryConfig.replace_url_placeholder(
            method.config.body,
            values.get("form_link"),
        )
        rendered_body = render_email_template(body_template, values)
        return rendered_subject, rendered_body

    def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]:
        """Collect the recipient addresses for *method*.

        The result contains non-empty addresses in first-seen order, with no
        duplicates. External addresses come first, in configuration order.
        Member addresses follow:

        - when ``whole_workspace`` is set, every member of *tenant_id*;
        - otherwise, the listed members found in that workspace, in listed order.
        """
        config = method.config.recipients
        collected: list[str] = []
        member_ids: list[str] = []
        for item in config.items:
            if isinstance(item, MemberRecipient):
                member_ids.append(item.user_id)
            elif isinstance(item, ExternalRecipient) and item.email:
                collected.append(item.email)

        if config.whole_workspace:
            # Individually listed members are ignored; the whole workspace is queried.
            everyone = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None)
            collected.extend(everyone.values())
        elif member_ids:
            lookup = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_ids)
            # Members missing from the lookup, or with an empty address, are skipped.
            collected.extend(filter(None, (lookup.get(uid) for uid in member_ids)))

        # dict.fromkeys drops duplicates while preserving first-seen order.
        return list(dict.fromkeys(addr for addr in collected if addr))

    def _query_workspace_member_emails(
        self,
        *,
        tenant_id: str,
        user_ids: list[str] | None,
    ) -> dict[str, str]:
        """Map account id to e-mail for accounts joined to *tenant_id*.

        Args:
            tenant_id: Workspace whose member accounts are queried.
            user_ids: ``None`` applies no id filter, so every member is
                returned. Otherwise only the given non-empty ids are queried.
                A list with no non-empty id returns ``{}`` without opening a
                session.
        """
        id_filter = None
        if user_ids is not None:
            id_filter = {uid for uid in user_ids if uid}
            if not id_filter:
                return {}

        query = (
            select(Account.id, Account.email)
            .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
            .where(TenantAccountJoin.tenant_id == tenant_id)
        )
        if id_filter is not None:
            query = query.where(Account.id.in_(id_filter))

        with self._session_factory() as session:
            fetched = session.execute(query).all()
        return dict(fetched)

    @staticmethod
    def _build_substitutions(
        *,
        context: DeliveryTestContext,
        recipient_email: str,
    ) -> dict[str, str]:
        """Build the template substitution map for one recipient.

        ``form_id``, ``workflow_run_id``, ``form_token`` and ``form_link`` are
        always empty for a test send. Any ``None`` value becomes ``""``.
        Non-``None`` entries from ``context.template_vars`` are merged last, so
        they override the built-in keys.
        """
        substitutions: dict[str, str] = {
            "form_id": "",
            "node_title": context.node_title or "",
            "workflow_run_id": "",
            "form_token": "",
            "form_link": "",
            "form_content": context.rendered_content or "",
            "recipient_email": recipient_email or "",
        }
        extra_vars = context.template_vars
        if extra_vars:
            for name, value in extra_vars.items():
                if value is not None:
                    substitutions[name] = value
        return substitutions
|