mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
- Added `name` parameter to `TriggerSubscriptionBuilderCreateApi` for better subscription identification.
- Implemented `delete_cache_for_subscription` function to clear the cache associated with trigger subscriptions.
- Updated `WorkflowPluginTriggerService` to check for existing subscriptions before creating new plugin triggers, improving error handling.
- Refactored `TriggerProviderService` to utilize the new cache deletion method during provider deletion.

This improves the overall management of trigger subscriptions and enhances cache efficiency.
320 lines
12 KiB
Python
320 lines
12 KiB
Python
import json
|
|
import logging
|
|
import uuid
|
|
from collections.abc import Mapping
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from flask import Request, Response
|
|
|
|
from core.plugin.entities.plugin import TriggerProviderID
|
|
from core.plugin.entities.plugin_daemon import CredentialType
|
|
from core.tools.errors import ToolProviderCredentialValidationError
|
|
from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity
|
|
from core.trigger.entities.entities import (
|
|
RequestLog,
|
|
SubscriptionBuilder,
|
|
)
|
|
from core.trigger.provider import PluginTriggerProviderController
|
|
from core.trigger.trigger_manager import TriggerManager
|
|
from core.trigger.utils.encryption import masked_credentials
|
|
from core.trigger.utils.endpoint import parse_endpoint_id
|
|
from extensions.ext_redis import redis_client
|
|
from services.trigger.trigger_provider_service import TriggerProviderService
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TriggerSubscriptionBuilderService:
    """Manage short-lived trigger subscription builders and their request logs.

    A subscription builder is a draft subscription that is serialized to JSON
    and stored in Redis under a per-builder key with a time-to-live. It can be
    updated, verified, and finally promoted to a persisted trigger provider
    subscription via ``build_trigger_subscription_builder``; otherwise it
    simply expires. Requests received on a builder's temporary endpoint are
    logged to a second Redis key with the same time-to-live.
    """

    ##########################
    # Trigger provider
    ##########################
    __MAX_TRIGGER_PROVIDER_COUNT__ = 10

    ##########################
    # Validation endpoint
    ##########################
    # Maximum number of request logs kept per endpoint.
    __VALIDATION_REQUEST_CACHE_COUNT__ = 10
    # Time-to-live of cached builders and logs: 30 minutes, in MILLISECONDS.
    # It must therefore be written with ``psetex`` (milliseconds), never with
    # ``setex`` (seconds).
    __VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 30 * 60 * 1000

    @classmethod
    def encode_cache_key(cls, subscription_id: str) -> str:
        """Return the Redis key under which the builder ``subscription_id`` is stored."""
        return f"trigger:subscription:validation:{subscription_id}"

    @classmethod
    def _encode_log_key(cls, endpoint_id: str) -> str:
        """Return the Redis key holding the request logs of ``endpoint_id``."""
        return f"trigger:subscription:validation:logs:{endpoint_id}"

    @classmethod
    def _store_builder(cls, subscription_id: str, subscription_builder: SubscriptionBuilder) -> None:
        """Serialize the builder to JSON and (re)write it to Redis with a fresh TTL."""
        redis_client.psetex(
            cls.encode_cache_key(subscription_id),
            cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__,
            subscription_builder.model_dump_json(),
        )

    @classmethod
    def verify_trigger_subscription_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
    ) -> Mapping[str, Any]:
        """Verify the credentials held by a subscription builder.

        :param tenant_id: Tenant that must own the builder
        :param user_id: User on whose behalf the credentials are validated
        :param provider_id: Trigger provider the builder was created for
        :param subscription_builder_id: ID of the cached builder
        :return: ``{"verified": bool}``. For OAuth2 the flag is whether any
            credentials are present; for API keys the credentials are validated
            by the provider controller; any other credential type is reported
            as verified.
        :raises ValueError: If the provider or builder is not found, the builder
            belongs to another tenant, or the API-key credentials are rejected
        """
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        subscription_builder = cls.get_subscription_builder(subscription_builder_id)
        # A builder owned by another tenant is reported as missing, matching the
        # ownership check in update_trigger_subscription_builder.
        if not subscription_builder or subscription_builder.tenant_id != tenant_id:
            raise ValueError(f"Subscription builder {subscription_builder_id} not found")

        if subscription_builder.credential_type == CredentialType.OAUTH2:
            return {"verified": bool(subscription_builder.credentials)}

        if subscription_builder.credential_type == CredentialType.API_KEY:
            try:
                provider_controller.validate_credentials(user_id, subscription_builder.credentials)
            except ToolProviderCredentialValidationError as e:
                # Chain the cause so the original validation error stays in the traceback.
                raise ValueError(f"Invalid credentials: {e}") from e

        return {"verified": True}

    @classmethod
    def build_trigger_subscription_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
    ) -> None:
        """Promote a subscription builder to a persisted trigger provider subscription.

        For the ``UNAUTHORIZED`` credential type the subscription is stored with
        the builder's own properties (manual creation). For every other type
        ``TriggerManager.subscribe_trigger`` is called first and the properties
        of its result are stored instead (automatic creation). The cached
        builder is deleted once the subscription has been added.

        :param tenant_id: Tenant that must own the builder
        :param user_id: User creating the subscription
        :param provider_id: Trigger provider to subscribe to
        :param subscription_builder_id: ID of the cached builder
        :raises ValueError: If the provider or builder is not found, the builder
            belongs to another tenant, or the builder has no name
        """
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        subscription_builder = cls.get_subscription_builder(subscription_builder_id)
        # A builder owned by another tenant is reported as missing, matching the
        # ownership check in update_trigger_subscription_builder.
        if not subscription_builder or subscription_builder.tenant_id != tenant_id:
            raise ValueError(f"Subscription builder {subscription_builder_id} not found")

        if not subscription_builder.name:
            raise ValueError("Subscription builder name is required")

        credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value)
        if credential_type == CredentialType.UNAUTHORIZED:
            # manually create
            properties = subscription_builder.properties
        else:
            # automatically create
            subscription = TriggerManager.subscribe_trigger(
                tenant_id=tenant_id,
                user_id=user_id,
                provider_id=provider_id,
                endpoint=parse_endpoint_id(subscription_builder.endpoint_id),
                parameters=subscription_builder.parameters,
                credentials=subscription_builder.credentials,
            )
            properties = subscription.properties

        TriggerProviderService.add_trigger_provider(
            subscription_id=subscription_builder.id,
            tenant_id=tenant_id,
            user_id=user_id,
            name=subscription_builder.name,
            provider_id=provider_id,
            endpoint_id=subscription_builder.endpoint_id,
            parameters=subscription_builder.parameters,
            properties=properties,
            credentials=subscription_builder.credentials,
            credential_type=credential_type,
            # A falsy expiry on the builder is stored as -1.
            credential_expires_at=subscription_builder.credential_expires_at or -1,
            expires_at=subscription_builder.expires_at,
        )

        cls.delete_trigger_subscription_builder(subscription_builder_id)

    @classmethod
    def create_trigger_subscription_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        credentials: Mapping[str, str],
        credential_type: CredentialType,
        credential_expires_at: int,
        expires_at: int,
        name: str | None,
    ) -> SubscriptionBuilder:
        """Create a new subscription builder and cache it in Redis.

        A fresh UUID is used as both the builder ID and its endpoint ID.
        Parameters and properties start from the provider's subscription schema
        defaults.

        :param tenant_id: Tenant owning the builder
        :param user_id: User creating the builder
        :param provider_id: Trigger provider the builder targets
        :param credentials: Initial credentials
        :param credential_type: Type of the supplied credentials
        :param credential_expires_at: Credential expiry value stored on the builder
        :param expires_at: Subscription expiry value stored on the builder
        :param name: Optional display name; stored as ``""`` when missing
        :return: The newly cached builder
        :raises ValueError: If the provider is not found
        """
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        subscription_schema = provider_controller.get_subscription_schema()
        subscription_id = str(uuid.uuid4())
        subscription_builder = SubscriptionBuilder(
            id=subscription_id,
            name=name or "",
            endpoint_id=subscription_id,
            tenant_id=tenant_id,
            user_id=user_id,
            provider_id=str(provider_id),
            parameters=subscription_schema.get_default_parameters(),
            properties=subscription_schema.get_default_properties(),
            credentials=credentials,
            credential_type=credential_type,
            credential_expires_at=credential_expires_at,
            expires_at=expires_at,
        )
        cls._store_builder(subscription_id, subscription_builder)
        return subscription_builder

    @classmethod
    def update_trigger_subscription_builder(
        cls,
        tenant_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
        name: str | None,
        parameters: Mapping[str, Any] | None,
        properties: Mapping[str, Any] | None,
        credentials: Mapping[str, str] | None,
    ) -> SubscriptionBuilderApiEntity:
        """Update selected fields of a cached subscription builder.

        Only truthy arguments are applied; ``None`` or empty values leave the
        corresponding field untouched. The builder is written back with a fresh
        time-to-live.

        :param tenant_id: Tenant that must own the builder
        :param provider_id: Trigger provider the builder targets
        :param subscription_builder_id: ID of the cached builder
        :param name: New name, if any
        :param parameters: New parameters, if any
        :param properties: New properties, if any
        :param credentials: New credentials, if any
        :return: API representation of the updated builder
        :raises ValueError: If the provider is not found, or the builder is
            missing, expired, or owned by another tenant
        """
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        subscription_builder = cls.get_subscription_builder(subscription_builder_id)
        if not subscription_builder or subscription_builder.tenant_id != tenant_id:
            raise ValueError(f"Subscription {subscription_builder_id} expired or not found")

        if name:
            subscription_builder.name = name
        if parameters:
            subscription_builder.parameters = parameters
        if properties:
            subscription_builder.properties = properties
        if credentials:
            subscription_builder.credentials = credentials

        cls._store_builder(subscription_builder_id, subscription_builder)
        return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder)

    @classmethod
    def builder_to_api_entity(
        cls, controller: PluginTriggerProviderController, entity: SubscriptionBuilder
    ) -> SubscriptionBuilderApiEntity:
        """Convert a cached builder to its API entity.

        Credentials are passed through ``masked_credentials`` together with the
        controller's credential schema for the builder's credential type.

        :param controller: Provider controller supplying the credential schema
        :param entity: Builder to convert
        :return: API representation of the builder
        """
        credential_type = CredentialType(entity.credential_type)
        return SubscriptionBuilderApiEntity(
            id=entity.id,
            name=entity.name or "",
            provider=entity.provider_id,
            endpoint=parse_endpoint_id(entity.endpoint_id),
            parameters=entity.parameters,
            properties=entity.properties,
            credential_type=credential_type,
            credentials=masked_credentials(
                schemas=controller.get_credentials_schema(credential_type),
                credentials=entity.credentials,
            ),
        )

    @classmethod
    def delete_trigger_subscription_builder(cls, subscription_id: str) -> None:
        """Remove the cached builder ``subscription_id`` from Redis."""
        redis_client.delete(cls.encode_cache_key(subscription_id))

    @classmethod
    def get_subscription_builder(cls, endpoint_id: str) -> SubscriptionBuilder | None:
        """Load a cached subscription builder.

        :param endpoint_id: Builder ID; builders are created with their endpoint
            ID equal to their own ID, so either identifies the cache entry
        :return: The builder, or ``None`` when no cache entry exists
        """
        subscription_cache = redis_client.get(cls.encode_cache_key(endpoint_id))
        if not subscription_cache:
            return None
        return SubscriptionBuilder.model_validate(json.loads(subscription_cache))

    @classmethod
    def append_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
        """Append one request/response pair to the endpoint's log list in Redis.

        Only the most recent ``__VALIDATION_REQUEST_CACHE_COUNT__`` entries are
        kept, and the list's time-to-live is refreshed on every write.

        :param endpoint_id: Endpoint the request was received on
        :param request: The incoming Flask request
        :param response: The Flask response that was produced
        """
        log = RequestLog(
            id=str(uuid.uuid4()),
            endpoint=endpoint_id,
            request={
                "method": request.method,
                "url": request.url,
                "headers": dict(request.headers),
                "data": request.get_data(as_text=True),
            },
            response={
                "status_code": response.status_code,
                "headers": dict(response.headers),
                "data": response.get_data(as_text=True),
            },
            created_at=datetime.now(),
        )

        key = cls._encode_log_key(endpoint_id)
        logs = json.loads(redis_client.get(key) or "[]")
        logs.append(log.model_dump(mode="json"))

        # Keep last N logs
        logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :]
        redis_client.psetex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, json.dumps(logs, default=str))

    @classmethod
    def list_logs(cls, endpoint_id: str) -> list[RequestLog]:
        """Return the request logs stored for ``endpoint_id`` (empty list if none)."""
        logs_json = redis_client.get(cls._encode_log_key(endpoint_id))
        if not logs_json:
            return []
        return [RequestLog.model_validate(log) for log in json.loads(logs_json)]

    @classmethod
    def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
        """
        Process a temporary endpoint request.

        The request is dispatched to the provider controller of the builder
        cached for ``endpoint_id`` and the exchange is appended to the
        endpoint's request log.

        :param endpoint_id: The endpoint identifier
        :param request: The Flask request object
        :return: The Flask response object, or ``None`` when no builder is
            cached for this endpoint
        :raises ValueError: If the builder's provider cannot be found
        """
        # check if validation endpoint exists
        subscription_builder = cls.get_subscription_builder(endpoint_id)
        if not subscription_builder:
            return None

        # response to validation endpoint
        provider_id = TriggerProviderID(subscription_builder.provider_id)
        controller = TriggerManager.get_trigger_provider(subscription_builder.tenant_id, provider_id)
        if not controller:
            # Same guard as the other entry points, instead of failing with an
            # AttributeError on ``None.dispatch``.
            raise ValueError(f"Provider {provider_id} not found")

        dispatch_result = controller.dispatch(
            user_id=subscription_builder.user_id,
            request=request,
            subscription=subscription_builder.to_subscription(),
        )
        response = dispatch_result.response

        # append the request log
        cls.append_log(endpoint_id, request, response)
        return response
|