mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
- Replaced calls to `create_trigger_provider_encrypter` and `create_trigger_provider_oauth_encrypter` with a unified `create_provider_encrypter` method, simplifying the encrypter creation process. - Updated the parameters passed to the new method to enhance configuration management and cache handling. These changes improve code clarity and maintainability in the trigger provider service.
531 lines
22 KiB
Python
531 lines
22 KiB
Python
import json
|
|
import logging
|
|
import uuid
|
|
from collections.abc import Mapping
|
|
from typing import Any, Optional
|
|
|
|
from sqlalchemy import desc, func
|
|
from sqlalchemy.orm import Session
|
|
|
|
from configs import dify_config
|
|
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
|
from core.helper.provider_encryption import create_provider_encrypter
|
|
from core.plugin.entities.plugin import TriggerProviderID
|
|
from core.plugin.entities.plugin_daemon import CredentialType
|
|
from core.plugin.impl.oauth import OAuthHandler
|
|
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
|
from core.trigger.entities.api_entities import (
|
|
TriggerProviderApiEntity,
|
|
TriggerProviderSubscriptionApiEntity,
|
|
)
|
|
from core.trigger.trigger_manager import TriggerManager
|
|
from core.trigger.utils.encryption import (
|
|
create_trigger_provider_encrypter_for_properties,
|
|
create_trigger_provider_encrypter_for_subscription,
|
|
delete_cache_for_subscription,
|
|
)
|
|
from extensions.ext_database import db
|
|
from extensions.ext_redis import redis_client
|
|
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerSubscription
|
|
from models.workflow import WorkflowPluginTrigger
|
|
from services.plugin.plugin_service import PluginService
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TriggerProviderService:
|
|
"""Service for managing trigger providers and credentials"""
|
|
|
|
##########################
|
|
# Trigger provider
|
|
##########################
|
|
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
|
|
|
|
@classmethod
|
|
def get_trigger_provider(cls, tenant_id: str, provider: TriggerProviderID) -> TriggerProviderApiEntity:
|
|
"""Get info for a trigger provider"""
|
|
return TriggerManager.get_trigger_provider(tenant_id, provider).to_api_entity()
|
|
|
|
@classmethod
|
|
def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]:
|
|
"""List all trigger providers for the current tenant"""
|
|
return [provider.to_api_entity() for provider in TriggerManager.list_all_trigger_providers(tenant_id)]
|
|
|
|
@classmethod
|
|
def list_trigger_provider_subscriptions(
|
|
cls, tenant_id: str, provider_id: TriggerProviderID
|
|
) -> list[TriggerProviderSubscriptionApiEntity]:
|
|
"""List all trigger subscriptions for the current tenant"""
|
|
subscriptions: list[TriggerProviderSubscriptionApiEntity] = []
|
|
workflows_in_use_map: dict[str, int] = {}
|
|
with Session(db.engine, expire_on_commit=False) as session:
|
|
# Get all subscriptions
|
|
subscriptions_db = (
|
|
session.query(TriggerSubscription)
|
|
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
|
|
.order_by(desc(TriggerSubscription.created_at))
|
|
.all()
|
|
)
|
|
subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db]
|
|
if not subscriptions:
|
|
return []
|
|
usage_counts = (
|
|
session.query(
|
|
WorkflowPluginTrigger.subscription_id,
|
|
func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"),
|
|
)
|
|
.filter(
|
|
WorkflowPluginTrigger.tenant_id == tenant_id,
|
|
WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]),
|
|
)
|
|
.group_by(WorkflowPluginTrigger.subscription_id)
|
|
.all()
|
|
)
|
|
workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts}
|
|
|
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
|
for subscription in subscriptions:
|
|
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
|
tenant_id=tenant_id,
|
|
controller=provider_controller,
|
|
subscription=subscription,
|
|
)
|
|
subscription.credentials = encrypter.mask_credentials(subscription.credentials)
|
|
count = workflows_in_use_map.get(subscription.id)
|
|
subscription.workflows_in_use = count if count is not None else 0
|
|
|
|
return subscriptions
|
|
|
|
@classmethod
|
|
def add_trigger_subscription(
|
|
cls,
|
|
tenant_id: str,
|
|
user_id: str,
|
|
name: str,
|
|
provider_id: TriggerProviderID,
|
|
endpoint_id: str,
|
|
credential_type: CredentialType,
|
|
parameters: Mapping[str, Any],
|
|
properties: Mapping[str, Any],
|
|
credentials: Mapping[str, str],
|
|
subscription_id: Optional[str] = None,
|
|
credential_expires_at: int = -1,
|
|
expires_at: int = -1,
|
|
) -> dict:
|
|
"""
|
|
Add a new trigger provider with credentials.
|
|
Supports multiple credential instances per provider.
|
|
|
|
:param tenant_id: Tenant ID
|
|
:param provider_id: Provider identifier (e.g., "plugin_id/provider_name")
|
|
:param credential_type: Type of credential (oauth or api_key)
|
|
:param credentials: Credential data to encrypt and store
|
|
:param name: Optional name for this credential instance
|
|
:param expires_at: OAuth token expiration timestamp
|
|
:return: Success response
|
|
"""
|
|
try:
|
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
|
with Session(db.engine, expire_on_commit=False) as session:
|
|
# Use distributed lock to prevent race conditions
|
|
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
|
|
with redis_client.lock(lock_key, timeout=20):
|
|
# Check provider count limit
|
|
provider_count = (
|
|
session.query(TriggerSubscription)
|
|
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
|
|
.count()
|
|
)
|
|
|
|
if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
|
|
raise ValueError(
|
|
f"Maximum number of providers ({cls.__MAX_TRIGGER_PROVIDER_COUNT__}) "
|
|
f"reached for {provider_id}"
|
|
)
|
|
|
|
# Check if name already exists
|
|
existing = (
|
|
session.query(TriggerSubscription)
|
|
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
|
|
.first()
|
|
)
|
|
if existing:
|
|
raise ValueError(f"Credential name '{name}' already exists for this provider")
|
|
|
|
credential_encrypter, _ = create_provider_encrypter(
|
|
tenant_id=tenant_id,
|
|
config=provider_controller.get_credential_schema_config(credential_type),
|
|
cache=NoOpProviderCredentialCache(),
|
|
)
|
|
|
|
properties_encrypter, _ = create_provider_encrypter(
|
|
tenant_id=tenant_id,
|
|
config=provider_controller.get_properties_schema(),
|
|
cache=NoOpProviderCredentialCache(),
|
|
)
|
|
|
|
# Create provider record
|
|
db_provider = TriggerSubscription(
|
|
id=subscription_id or str(uuid.uuid4()),
|
|
tenant_id=tenant_id,
|
|
user_id=user_id,
|
|
name=name,
|
|
endpoint_id=endpoint_id,
|
|
provider_id=str(provider_id),
|
|
parameters=parameters,
|
|
properties=properties_encrypter.encrypt(dict(properties)),
|
|
credentials=credential_encrypter.encrypt(dict(credentials)),
|
|
credential_type=credential_type.value,
|
|
credential_expires_at=credential_expires_at,
|
|
expires_at=expires_at,
|
|
)
|
|
|
|
session.add(db_provider)
|
|
session.commit()
|
|
|
|
return {"result": "success", "id": str(db_provider.id)}
|
|
|
|
except Exception as e:
|
|
logger.exception("Failed to add trigger provider")
|
|
raise ValueError(str(e))
|
|
|
|
@classmethod
|
|
def get_subscription_by_id(
|
|
cls, tenant_id: str, subscription_id: str | None = None
|
|
) -> TriggerProviderSubscriptionApiEntity | None:
|
|
"""
|
|
Get a trigger subscription by the ID.
|
|
"""
|
|
with Session(db.engine, expire_on_commit=False) as session:
|
|
subscription: TriggerSubscription | None = None
|
|
if subscription_id:
|
|
subscription = (
|
|
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
|
)
|
|
else:
|
|
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first()
|
|
if subscription:
|
|
provider_controller = TriggerManager.get_trigger_provider(
|
|
tenant_id, TriggerProviderID(subscription.provider_id)
|
|
)
|
|
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
|
tenant_id=tenant_id,
|
|
controller=provider_controller,
|
|
subscription=subscription,
|
|
)
|
|
subscription.credentials = encrypter.decrypt(subscription.credentials)
|
|
return subscription.to_api_entity()
|
|
return None
|
|
|
|
@classmethod
|
|
def delete_trigger_provider(cls, session: Session, tenant_id: str, subscription_id: str):
|
|
"""
|
|
Delete a trigger provider subscription within an existing session.
|
|
|
|
:param session: Database session
|
|
:param tenant_id: Tenant ID
|
|
:param subscription_id: Subscription instance ID
|
|
:return: Success response
|
|
"""
|
|
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
|
if not db_provider:
|
|
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
|
|
|
# Clear cache
|
|
session.delete(db_provider)
|
|
delete_cache_for_subscription(
|
|
tenant_id=tenant_id,
|
|
provider_id=db_provider.provider_id,
|
|
subscription_id=db_provider.id,
|
|
)
|
|
|
|
@classmethod
|
|
def refresh_oauth_token(
|
|
cls,
|
|
tenant_id: str,
|
|
subscription_id: str,
|
|
) -> dict:
|
|
"""
|
|
Refresh OAuth token for a trigger provider.
|
|
|
|
:param tenant_id: Tenant ID
|
|
:param subscription_id: Subscription instance ID
|
|
:return: New token info
|
|
"""
|
|
with Session(db.engine) as session:
|
|
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
|
|
|
if not db_provider:
|
|
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
|
|
|
if db_provider.credential_type != CredentialType.OAUTH2.value:
|
|
raise ValueError("Only OAuth credentials can be refreshed")
|
|
|
|
provider_id = TriggerProviderID(db_provider.provider_id)
|
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
|
# Create encrypter
|
|
encrypter, cache = create_provider_encrypter(
|
|
tenant_id=tenant_id,
|
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
|
cache=NoOpProviderCredentialCache(),
|
|
)
|
|
|
|
# Decrypt current credentials
|
|
current_credentials = encrypter.decrypt(db_provider.credentials)
|
|
|
|
# Get OAuth client configuration
|
|
redirect_uri = (
|
|
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{db_provider.provider_id}/trigger/callback"
|
|
)
|
|
system_credentials = cls.get_oauth_client(tenant_id, provider_id)
|
|
|
|
# Refresh token
|
|
oauth_handler = OAuthHandler()
|
|
refreshed_credentials = oauth_handler.refresh_credentials(
|
|
tenant_id=tenant_id,
|
|
user_id=db_provider.user_id,
|
|
plugin_id=provider_id.plugin_id,
|
|
provider=provider_id.provider_name,
|
|
redirect_uri=redirect_uri,
|
|
system_credentials=system_credentials or {},
|
|
credentials=current_credentials,
|
|
)
|
|
|
|
# Update credentials
|
|
db_provider.credentials = encrypter.encrypt(dict(refreshed_credentials.credentials))
|
|
db_provider.expires_at = refreshed_credentials.expires_at
|
|
session.commit()
|
|
|
|
# Clear cache
|
|
cache.delete()
|
|
|
|
return {
|
|
"result": "success",
|
|
"expires_at": refreshed_credentials.expires_at,
|
|
}
|
|
|
|
@classmethod
|
|
def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Optional[Mapping[str, Any]]:
|
|
"""
|
|
Get OAuth client configuration for a provider.
|
|
First tries tenant-level OAuth, then falls back to system OAuth.
|
|
|
|
:param tenant_id: Tenant ID
|
|
:param provider_id: Provider identifier
|
|
:return: OAuth client configuration or None
|
|
"""
|
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
|
with Session(db.engine, expire_on_commit=False) as session:
|
|
tenant_client: TriggerOAuthTenantClient | None = (
|
|
session.query(TriggerOAuthTenantClient)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
provider=provider_id.provider_name,
|
|
plugin_id=provider_id.plugin_id,
|
|
enabled=True,
|
|
)
|
|
.first()
|
|
)
|
|
|
|
oauth_params: Mapping[str, Any] | None = None
|
|
if tenant_client:
|
|
encrypter, _ = create_provider_encrypter(
|
|
tenant_id=tenant_id,
|
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
|
cache=NoOpProviderCredentialCache(),
|
|
)
|
|
oauth_params = encrypter.decrypt(tenant_client.oauth_params)
|
|
return oauth_params
|
|
|
|
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
|
|
if not is_verified:
|
|
return oauth_params
|
|
|
|
# Check for system-level OAuth client
|
|
system_client: TriggerOAuthSystemClient | None = (
|
|
session.query(TriggerOAuthSystemClient)
|
|
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
|
|
.first()
|
|
)
|
|
|
|
if system_client:
|
|
try:
|
|
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
|
|
except Exception as e:
|
|
raise ValueError(f"Error decrypting system oauth params: {e}")
|
|
|
|
return oauth_params
|
|
|
|
@classmethod
|
|
def save_custom_oauth_client_params(
|
|
cls,
|
|
tenant_id: str,
|
|
provider_id: TriggerProviderID,
|
|
client_params: Optional[dict] = None,
|
|
enabled: Optional[bool] = None,
|
|
) -> dict:
|
|
"""
|
|
Save or update custom OAuth client parameters for a trigger provider.
|
|
|
|
:param tenant_id: Tenant ID
|
|
:param provider_id: Provider identifier
|
|
:param client_params: OAuth client parameters (client_id, client_secret, etc.)
|
|
:param enabled: Enable/disable the custom OAuth client
|
|
:return: Success response
|
|
"""
|
|
if client_params is None and enabled is None:
|
|
return {"result": "success"}
|
|
|
|
# Get provider controller to access schema
|
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
|
|
|
with Session(db.engine) as session:
|
|
# Find existing custom client params
|
|
custom_client = (
|
|
session.query(TriggerOAuthTenantClient)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
plugin_id=provider_id.plugin_id,
|
|
provider=provider_id.provider_name,
|
|
)
|
|
.first()
|
|
)
|
|
|
|
# Create new record if doesn't exist
|
|
if custom_client is None:
|
|
custom_client = TriggerOAuthTenantClient(
|
|
tenant_id=tenant_id,
|
|
plugin_id=provider_id.plugin_id,
|
|
provider=provider_id.provider_name,
|
|
)
|
|
session.add(custom_client)
|
|
|
|
# Update client params if provided
|
|
if client_params is not None:
|
|
encrypter, cache = create_provider_encrypter(
|
|
tenant_id=tenant_id,
|
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
|
cache=NoOpProviderCredentialCache(),
|
|
)
|
|
|
|
# Handle hidden values
|
|
original_params = encrypter.decrypt(custom_client.oauth_params)
|
|
new_params: dict = {
|
|
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
|
for key, value in client_params.items()
|
|
}
|
|
custom_client.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
|
|
cache.delete()
|
|
|
|
# Update enabled status if provided
|
|
if enabled is not None:
|
|
custom_client.enabled = enabled
|
|
|
|
session.commit()
|
|
|
|
return {"result": "success"}
|
|
|
|
@classmethod
|
|
def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
|
|
"""
|
|
Get custom OAuth client parameters for a trigger provider.
|
|
|
|
:param tenant_id: Tenant ID
|
|
:param provider_id: Provider identifier
|
|
:return: Masked OAuth client parameters
|
|
"""
|
|
with Session(db.engine) as session:
|
|
custom_client = (
|
|
session.query(TriggerOAuthTenantClient)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
plugin_id=provider_id.plugin_id,
|
|
provider=provider_id.provider_name,
|
|
)
|
|
.first()
|
|
)
|
|
|
|
if custom_client is None:
|
|
return {}
|
|
|
|
# Get provider controller to access schema
|
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
|
|
|
# Create encrypter to decrypt and mask values
|
|
encrypter, _ = create_provider_encrypter(
|
|
tenant_id=tenant_id,
|
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
|
cache=NoOpProviderCredentialCache(),
|
|
)
|
|
|
|
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_client.oauth_params))
|
|
|
|
@classmethod
|
|
def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
|
|
"""
|
|
Delete custom OAuth client parameters for a trigger provider.
|
|
|
|
:param tenant_id: Tenant ID
|
|
:param provider_id: Provider identifier
|
|
:return: Success response
|
|
"""
|
|
with Session(db.engine) as session:
|
|
session.query(TriggerOAuthTenantClient).filter_by(
|
|
tenant_id=tenant_id,
|
|
provider=provider_id.provider_name,
|
|
plugin_id=provider_id.plugin_id,
|
|
).delete()
|
|
session.commit()
|
|
|
|
return {"result": "success"}
|
|
|
|
@classmethod
|
|
def is_oauth_custom_client_enabled(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool:
|
|
"""
|
|
Check if custom OAuth client is enabled for a trigger provider.
|
|
|
|
:param tenant_id: Tenant ID
|
|
:param provider_id: Provider identifier
|
|
:return: True if enabled, False otherwise
|
|
"""
|
|
with Session(db.engine, expire_on_commit=False) as session:
|
|
custom_client = (
|
|
session.query(TriggerOAuthTenantClient)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
plugin_id=provider_id.plugin_id,
|
|
provider=provider_id.provider_name,
|
|
enabled=True,
|
|
)
|
|
.first()
|
|
)
|
|
return custom_client is not None
|
|
|
|
@classmethod
|
|
def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None:
|
|
"""
|
|
Get a trigger subscription by the endpoint ID.
|
|
"""
|
|
with Session(db.engine, expire_on_commit=False) as session:
|
|
subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
|
|
if not subscription:
|
|
return None
|
|
provider_controller = TriggerManager.get_trigger_provider(
|
|
subscription.tenant_id, TriggerProviderID(subscription.provider_id)
|
|
)
|
|
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
|
tenant_id=subscription.tenant_id,
|
|
controller=provider_controller,
|
|
subscription=subscription,
|
|
)
|
|
subscription.credentials = credential_encrypter.decrypt(subscription.credentials)
|
|
|
|
properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
|
|
tenant_id=subscription.tenant_id,
|
|
controller=provider_controller,
|
|
subscription=subscription,
|
|
)
|
|
subscription.properties = properties_encrypter.decrypt(subscription.properties)
|
|
return subscription
|