From 925825a41b68ee1432414cfad18d3c42237558cb Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 9 Jan 2026 16:50:32 +0800 Subject: [PATCH] refactor(encryption): using oauth encryption as a general encryption util. --- ...uth_encryption.py => system_encryption.py} | 82 +++--- .../tools/builtin_tools_manage_service.py | 4 +- .../trigger/trigger_provider_service.py | 4 +- .../test_system_encryption.py} | 272 +++++++++--------- 4 files changed, 181 insertions(+), 181 deletions(-) rename api/core/tools/utils/{system_oauth_encryption.py => system_encryption.py} (57%) rename api/tests/unit_tests/utils/{oauth_encryption/test_system_oauth_encryption.py => encryption/test_system_encryption.py} (68%) diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_encryption.py similarity index 57% rename from api/core/tools/utils/system_oauth_encryption.py rename to api/core/tools/utils/system_encryption.py index 6b7007842d..fa4625608b 100644 --- a/api/core/tools/utils/system_oauth_encryption.py +++ b/api/core/tools/utils/system_encryption.py @@ -14,23 +14,23 @@ from configs import dify_config logger = logging.getLogger(__name__) -class OAuthEncryptionError(Exception): - """OAuth encryption/decryption specific error""" +class EncryptionError(Exception): + """Encryption/decryption specific error""" pass -class SystemOAuthEncrypter: +class SystemEncrypter: """ - A simple OAuth parameters encrypter using AES-CBC encryption. + A simple parameters encrypter using AES-CBC encryption. - This class provides methods to encrypt and decrypt OAuth parameters + This class provides methods to encrypt and decrypt parameters using AES-CBC mode with a key derived from the application's SECRET_KEY. """ def __init__(self, secret_key: str | None = None): """ - Initialize the OAuth encrypter. + Initialize the encrypter. Args: secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY @@ -43,19 +43,19 @@ class SystemOAuthEncrypter: # Generate a fixed 256-bit key using SHA-256 self.key = hashlib.sha256(secret_key.encode()).digest() - def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str: + def encrypt_params(self, params: Mapping[str, Any]) -> str: """ - Encrypt OAuth parameters. + Encrypt parameters. Args: - oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} + params: parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} Returns: Base64-encoded encrypted string Raises: - OAuthEncryptionError: If encryption fails - ValueError: If oauth_params is invalid + EncryptionError: If encryption fails + ValueError: If params is invalid """ try: @@ -66,7 +66,7 @@ class SystemOAuthEncrypter: cipher = AES.new(self.key, AES.MODE_CBC, iv) # Encrypt data - padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size) + padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size) encrypted_data = cipher.encrypt(padded_data) # Combine IV and encrypted data @@ -76,20 +76,20 @@ class SystemOAuthEncrypter: return base64.b64encode(combined).decode() except Exception as e: - raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e + raise EncryptionError(f"Encryption failed: {str(e)}") from e - def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]: + def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]: """ - Decrypt OAuth parameters. + Decrypt parameters. Args: encrypted_data: Base64-encoded encrypted string Returns: - Decrypted OAuth parameters dictionary + Decrypted parameters dictionary Raises: - OAuthEncryptionError: If decryption fails + EncryptionError: If decryption fails ValueError: If encrypted_data is invalid """ if not isinstance(encrypted_data, str): @@ -118,70 +118,70 @@ class SystemOAuthEncrypter: unpadded_data = unpad(decrypted_data, AES.block_size) # Parse JSON - oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) + params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) - if not isinstance(oauth_params, dict): + if not isinstance(params, dict): raise ValueError("Decrypted data is not a valid dictionary") - return oauth_params + return params except Exception as e: - raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e + raise EncryptionError(f"Decryption failed: {str(e)}") from e # Factory function for creating encrypter instances -def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter: +def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter: """ - Create an OAuth encrypter instance. + Create an encrypter instance. Args: secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY Returns: - SystemOAuthEncrypter instance + SystemEncrypter instance """ - return SystemOAuthEncrypter(secret_key=secret_key) + return SystemEncrypter(secret_key=secret_key) # Global encrypter instance (for backward compatibility) -_oauth_encrypter: SystemOAuthEncrypter | None = None +_encrypter: SystemEncrypter | None = None -def get_system_oauth_encrypter() -> SystemOAuthEncrypter: +def get_system_encrypter() -> SystemEncrypter: """ - Get the global OAuth encrypter instance. + Get the global encrypter instance. Returns: - SystemOAuthEncrypter instance + SystemEncrypter instance """ - global _oauth_encrypter - if _oauth_encrypter is None: - _oauth_encrypter = SystemOAuthEncrypter() - return _oauth_encrypter + global _encrypter + if _encrypter is None: + _encrypter = SystemEncrypter() + return _encrypter # Convenience functions for backward compatibility -def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str: +def encrypt_system_params(params: Mapping[str, Any]) -> str: """ - Encrypt OAuth parameters using the global encrypter. + Encrypt parameters using the global encrypter. Args: - oauth_params: OAuth parameters dictionary + params: parameters dictionary Returns: Base64-encoded encrypted string """ - return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params) + return get_system_encrypter().encrypt_params(params) -def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]: +def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]: """ - Decrypt OAuth parameters using the global encrypter. + Decrypt parameters using the global encrypter. Args: encrypted_data: Base64-encoded encrypted string Returns: - Decrypted OAuth parameters dictionary + Decrypted parameters dictionary """ - return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data) + return get_system_encrypter().decrypt_params(encrypted_data) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 6797a67dde..4c37a867b1 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_provider_encrypter -from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params +from core.tools.utils.system_encryption import decrypt_system_params from extensions.ext_database import db from extensions.ext_redis import redis_client from models.provider_ids import ToolProviderID @@ -502,7 +502,7 @@ class BuiltinToolManageService: ) if system_client: try: - oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + oauth_params = decrypt_system_params(system_client.encrypted_oauth_params) except Exception as e: raise ValueError(f"Error decrypting system oauth params: {e}") diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 688993c798..9e55cb14bf 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler -from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params +from core.tools.utils.system_encryption import decrypt_system_params from core.trigger.entities.api_entities import ( TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity, @@ -591,7 +591,7 @@ class TriggerProviderService: if system_client: try: - oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + oauth_params = decrypt_system_params(system_client.encrypted_oauth_params) except Exception as e: raise ValueError(f"Error decrypting system oauth params: {e}") diff --git a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py b/api/tests/unit_tests/utils/encryption/test_system_encryption.py similarity index 68% rename from api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py rename to api/tests/unit_tests/utils/encryption/test_system_encryption.py index e2607f0fb1..cfa381eb21 100644 --- a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py +++ b/api/tests/unit_tests/utils/encryption/test_system_encryption.py @@ -7,13 +7,13 @@ from Crypto.Cipher import AES from Crypto.Random import get_random_bytes from Crypto.Util.Padding import pad -from core.tools.utils.system_oauth_encryption import ( - OAuthEncryptionError, - SystemOAuthEncrypter, - create_system_oauth_encrypter, - decrypt_system_oauth_params, - encrypt_system_oauth_params, - get_system_oauth_encrypter, +from core.tools.utils.system_encryption import ( + EncryptionError, + SystemEncrypter, + create_system_encrypter, + decrypt_system_params, + encrypt_system_params, + get_system_encrypter, ) @@ -23,7 +23,7 @@ class TestSystemOAuthEncrypter: def test_init_with_secret_key(self): """Test initialization with provided secret key""" secret_key = "test_secret_key" - encrypter = SystemOAuthEncrypter(secret_key=secret_key) + encrypter = SystemEncrypter(secret_key=secret_key) expected_key = hashlib.sha256(secret_key.encode()).digest() assert encrypter.key == expected_key @@ -31,13 +31,13 @@ class TestSystemOAuthEncrypter: """Test initialization with None secret key falls back to config""" with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: mock_config.SECRET_KEY = "config_secret" - encrypter = SystemOAuthEncrypter(secret_key=None) + encrypter = SystemEncrypter(secret_key=None) expected_key = hashlib.sha256(b"config_secret").digest() assert encrypter.key == expected_key def test_init_with_empty_secret_key(self): """Test initialization with empty secret key""" - encrypter = SystemOAuthEncrypter(secret_key="") + encrypter = SystemEncrypter(secret_key="") expected_key = hashlib.sha256(b"").digest() assert encrypter.key == expected_key @@ -45,16 +45,16 @@ class TestSystemOAuthEncrypter: """Test initialization without secret key uses config""" with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: mock_config.SECRET_KEY = "default_secret" - encrypter = SystemOAuthEncrypter() + encrypter = SystemEncrypter() expected_key = hashlib.sha256(b"default_secret").digest() assert encrypter.key == expected_key def test_encrypt_oauth_params_basic(self): """Test basic OAuth parameters encryption""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - encrypted = encrypter.encrypt_oauth_params(oauth_params) + encrypted = encrypter.encrypt_params(oauth_params) assert isinstance(encrypted, str) assert len(encrypted) > 0 @@ -66,16 +66,16 @@ class TestSystemOAuthEncrypter: def test_encrypt_oauth_params_empty_dict(self): """Test encryption with empty dictionary""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") oauth_params = {} - encrypted = encrypter.encrypt_oauth_params(oauth_params) + encrypted = encrypter.encrypt_params(oauth_params) assert isinstance(encrypted, str) assert len(encrypted) > 0 def test_encrypt_oauth_params_complex_data(self): """Test encryption with complex data structures""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") oauth_params = { "client_id": "test_id", "client_secret": "test_secret", @@ -86,64 +86,64 @@ class TestSystemOAuthEncrypter: "null_value": None, } - encrypted = encrypter.encrypt_oauth_params(oauth_params) + encrypted = encrypter.encrypt_params(oauth_params) assert isinstance(encrypted, str) assert len(encrypted) > 0 def test_encrypt_oauth_params_unicode_data(self): """Test encryption with unicode data""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case πŸš€"} - encrypted = encrypter.encrypt_oauth_params(oauth_params) + encrypted = encrypter.encrypt_params(oauth_params) assert isinstance(encrypted, str) assert len(encrypted) > 0 def test_encrypt_oauth_params_large_data(self): """Test encryption with large data""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") oauth_params = { "client_id": "test_id", "large_data": "x" * 10000, # 10KB of data } - encrypted = encrypter.encrypt_oauth_params(oauth_params) + encrypted = encrypter.encrypt_params(oauth_params) assert isinstance(encrypted, str) assert len(encrypted) > 0 def test_encrypt_oauth_params_invalid_input(self): """Test encryption with invalid input types""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params(None) + encrypter.encrypt_params(None) with pytest.raises(Exception): # noqa: B017 - encrypter.encrypt_oauth_params("not_a_dict") + encrypter.encrypt_params("not_a_dict") def test_decrypt_oauth_params_basic(self): """Test basic OAuth parameters decryption""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") original_params = {"client_id": "test_id", "client_secret": "test_secret"} - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == original_params def test_decrypt_oauth_params_empty_dict(self): """Test decryption of empty dictionary""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") original_params = {} - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == original_params def test_decrypt_oauth_params_complex_data(self): """Test decryption with complex data structures""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") original_params = { "client_id": "test_id", "client_secret": "test_secret", @@ -154,104 +154,104 @@ class TestSystemOAuthEncrypter: "null_value": None, } - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == original_params def test_decrypt_oauth_params_unicode_data(self): """Test decryption with unicode data""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") original_params = { "client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case πŸš€", } - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == original_params def test_decrypt_oauth_params_large_data(self): """Test decryption with large data""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") original_params = { "client_id": "test_id", "large_data": "x" * 10000, # 10KB of data } - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == original_params def test_decrypt_oauth_params_invalid_base64(self): """Test decryption with invalid base64 data""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") - with pytest.raises(OAuthEncryptionError): - encrypter.decrypt_oauth_params("invalid_base64!") + with pytest.raises(EncryptionError): + encrypter.decrypt_params("invalid_base64!") def test_decrypt_oauth_params_empty_string(self): """Test decryption with empty string""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params("") + encrypter.decrypt_params("") assert "encrypted_data cannot be empty" in str(exc_info.value) def test_decrypt_oauth_params_non_string_input(self): """Test decryption with non-string input""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) + encrypter.decrypt_params(123) assert "encrypted_data must be a string" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(None) + encrypter.decrypt_params(None) assert "encrypted_data must be a string" in str(exc_info.value) def test_decrypt_oauth_params_too_short_data(self): """Test decryption with too short encrypted data""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") # Create data that's too short (less than 32 bytes) short_data = base64.b64encode(b"short").decode() - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.decrypt_oauth_params(short_data) + with pytest.raises(EncryptionError) as exc_info: + encrypter.decrypt_params(short_data) assert "Invalid encrypted data format" in str(exc_info.value) def test_decrypt_oauth_params_corrupted_data(self): """Test decryption with corrupted data""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") # Create corrupted data (valid base64 but invalid encrypted content) corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage - with pytest.raises(OAuthEncryptionError): - encrypter.decrypt_oauth_params(corrupted_data) + with pytest.raises(EncryptionError): + encrypter.decrypt_params(corrupted_data) def test_decrypt_oauth_params_wrong_key(self): """Test decryption with wrong key""" - encrypter1 = SystemOAuthEncrypter("secret1") - encrypter2 = SystemOAuthEncrypter("secret2") + encrypter1 = SystemEncrypter("secret1") + encrypter2 = SystemEncrypter("secret2") original_params = {"client_id": "test_id", "client_secret": "test_secret"} - encrypted = encrypter1.encrypt_oauth_params(original_params) + encrypted = encrypter1.encrypt_params(original_params) - with pytest.raises(OAuthEncryptionError): - encrypter2.decrypt_oauth_params(encrypted) + with pytest.raises(EncryptionError): + encrypter2.decrypt_params(encrypted) def test_encryption_decryption_consistency(self): """Test that encryption and decryption are consistent""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") test_cases = [ {}, @@ -264,42 +264,42 @@ class TestSystemOAuthEncrypter: ] for original_params in test_cases: - encrypted = encrypter.encrypt_oauth_params(original_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(original_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == original_params, f"Failed for case: {original_params}" def test_encryption_randomness(self): """Test that encryption produces different results for same input""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - encrypted1 = encrypter.encrypt_oauth_params(oauth_params) - encrypted2 = encrypter.encrypt_oauth_params(oauth_params) + encrypted1 = encrypter.encrypt_params(oauth_params) + encrypted2 = encrypter.encrypt_params(oauth_params) # Should be different due to random IV assert encrypted1 != encrypted2 # But should decrypt to same result - decrypted1 = encrypter.decrypt_oauth_params(encrypted1) - decrypted2 = encrypter.decrypt_oauth_params(encrypted2) + decrypted1 = encrypter.decrypt_params(encrypted1) + decrypted2 = encrypter.decrypt_params(encrypted2) assert decrypted1 == decrypted2 == oauth_params def test_different_secret_keys_produce_different_results(self): """Test that different secret keys produce different encrypted results""" - encrypter1 = SystemOAuthEncrypter("secret1") - encrypter2 = SystemOAuthEncrypter("secret2") + encrypter1 = SystemEncrypter("secret1") + encrypter2 = SystemEncrypter("secret2") oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - encrypted1 = encrypter1.encrypt_oauth_params(oauth_params) - encrypted2 = encrypter2.encrypt_oauth_params(oauth_params) + encrypted1 = encrypter1.encrypt_params(oauth_params) + encrypted2 = encrypter2.encrypt_params(oauth_params) # Should produce different encrypted results assert encrypted1 != encrypted2 # But each should decrypt correctly with its own key - decrypted1 = encrypter1.decrypt_oauth_params(encrypted1) - decrypted2 = encrypter2.decrypt_oauth_params(encrypted2) + decrypted1 = encrypter1.decrypt_params(encrypted1) + decrypted2 = encrypter2.decrypt_params(encrypted2) assert decrypted1 == decrypted2 == oauth_params @patch("core.tools.utils.system_oauth_encryption.get_random_bytes") @@ -307,11 +307,11 @@ class TestSystemOAuthEncrypter: """Test encryption when crypto operation fails""" mock_get_random_bytes.side_effect = Exception("Crypto error") - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") oauth_params = {"client_id": "test_id"} - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.encrypt_oauth_params(oauth_params) + with pytest.raises(EncryptionError) as exc_info: + encrypter.encrypt_params(oauth_params) assert "Encryption failed" in str(exc_info.value) @@ -320,17 +320,17 @@ class TestSystemOAuthEncrypter: """Test encryption when JSON serialization fails""" mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error") - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") oauth_params = {"client_id": "test_id"} - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.encrypt_oauth_params(oauth_params) + with pytest.raises(EncryptionError) as exc_info: + encrypter.encrypt_params(oauth_params) assert "Encryption failed" in str(exc_info.value) def test_decrypt_oauth_params_invalid_json(self): """Test decryption with invalid JSON data""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") # Create valid encrypted data but with invalid JSON content iv = get_random_bytes(16) @@ -341,14 +341,14 @@ class TestSystemOAuthEncrypter: combined = iv + encrypted_data encoded = base64.b64encode(combined).decode() - with pytest.raises(OAuthEncryptionError): - encrypter.decrypt_oauth_params(encoded) + with pytest.raises(EncryptionError): + encrypter.decrypt_params(encoded) def test_key_derivation_consistency(self): """Test that key derivation is consistent""" secret_key = "test_secret" - encrypter1 = SystemOAuthEncrypter(secret_key) - encrypter2 = SystemOAuthEncrypter(secret_key) + encrypter1 = SystemEncrypter(secret_key) + encrypter2 = SystemEncrypter(secret_key) assert encrypter1.key == encrypter2.key @@ -362,9 +362,9 @@ class TestFactoryFunctions: def test_create_system_oauth_encrypter_with_secret(self): """Test factory function with secret key""" secret_key = "test_secret" - encrypter = create_system_oauth_encrypter(secret_key) + encrypter = create_system_encrypter(secret_key) - assert isinstance(encrypter, SystemOAuthEncrypter) + assert isinstance(encrypter, SystemEncrypter) expected_key = hashlib.sha256(secret_key.encode()).digest() assert encrypter.key == expected_key @@ -372,9 +372,9 @@ class TestFactoryFunctions: """Test factory function without secret key""" with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: mock_config.SECRET_KEY = "config_secret" - encrypter = create_system_oauth_encrypter() + encrypter = create_system_encrypter() - assert isinstance(encrypter, SystemOAuthEncrypter) + assert isinstance(encrypter, SystemEncrypter) expected_key = hashlib.sha256(b"config_secret").digest() assert encrypter.key == expected_key @@ -382,9 +382,9 @@ class TestFactoryFunctions: """Test factory function with None secret key""" with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: mock_config.SECRET_KEY = "config_secret" - encrypter = create_system_oauth_encrypter(None) + encrypter = create_system_encrypter(None) - assert isinstance(encrypter, SystemOAuthEncrypter) + assert isinstance(encrypter, SystemEncrypter) expected_key = hashlib.sha256(b"config_secret").digest() assert encrypter.key == expected_key @@ -395,26 +395,26 @@ class TestGlobalEncrypterInstance: def test_get_system_oauth_encrypter_singleton(self): """Test that get_system_oauth_encrypter returns singleton instance""" # Clear the global instance first - import core.tools.utils.system_oauth_encryption + import core.tools.utils.system_encryption - core.tools.utils.system_oauth_encryption._oauth_encrypter = None + core.tools.utils.system_encryption._encrypter = None - encrypter1 = get_system_oauth_encrypter() - encrypter2 = get_system_oauth_encrypter() + encrypter1 = get_system_encrypter() + encrypter2 = get_system_encrypter() assert encrypter1 is encrypter2 - assert isinstance(encrypter1, SystemOAuthEncrypter) + assert isinstance(encrypter1, SystemEncrypter) def test_get_system_oauth_encrypter_uses_config(self): """Test that global encrypter uses config""" # Clear the global instance first - import core.tools.utils.system_oauth_encryption + import core.tools.utils.system_encryption - core.tools.utils.system_oauth_encryption._oauth_encrypter = None + core.tools.utils.system_encryption._encrypter = None with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config: mock_config.SECRET_KEY = "global_secret" - encrypter = get_system_oauth_encrypter() + encrypter = get_system_encrypter() expected_key = hashlib.sha256(b"global_secret").digest() assert encrypter.key == expected_key @@ -427,7 +427,7 @@ class TestConvenienceFunctions: """Test encrypt_system_oauth_params convenience function""" oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - encrypted = encrypt_system_oauth_params(oauth_params) + encrypted = encrypt_system_params(oauth_params) assert isinstance(encrypted, str) assert len(encrypted) > 0 @@ -436,8 +436,8 @@ class TestConvenienceFunctions: """Test decrypt_system_oauth_params convenience function""" oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} - encrypted = encrypt_system_oauth_params(oauth_params) - decrypted = decrypt_system_oauth_params(encrypted) + encrypted = encrypt_system_params(oauth_params) + decrypted = decrypt_system_params(encrypted) assert decrypted == oauth_params @@ -453,22 +453,22 @@ class TestConvenienceFunctions: ] for original_params in test_cases: - encrypted = encrypt_system_oauth_params(original_params) - decrypted = decrypt_system_oauth_params(encrypted) + encrypted = encrypt_system_params(original_params) + decrypted = decrypt_system_params(encrypted) assert decrypted == original_params, f"Failed for case: {original_params}" def test_convenience_functions_with_errors(self): """Test convenience functions with error conditions""" # Test encryption with invalid input with pytest.raises(Exception): # noqa: B017 - encrypt_system_oauth_params(None) + encrypt_system_params(None) # Test decryption with invalid input with pytest.raises(ValueError): - decrypt_system_oauth_params("") + decrypt_system_params("") with pytest.raises(ValueError): - decrypt_system_oauth_params(None) + decrypt_system_params(None) class TestErrorHandling: @@ -476,14 +476,14 @@ class TestErrorHandling: def test_oauth_encryption_error_inheritance(self): """Test that OAuthEncryptionError is a proper exception""" - error = OAuthEncryptionError("Test error") + error = EncryptionError("Test error") assert isinstance(error, Exception) assert str(error) == "Test error" def test_oauth_encryption_error_with_cause(self): """Test OAuthEncryptionError with cause""" original_error = ValueError("Original error") - error = OAuthEncryptionError("Wrapper error") + error = EncryptionError("Wrapper error") error.__cause__ = original_error assert isinstance(error, Exception) @@ -492,22 +492,22 @@ class TestErrorHandling: def test_error_messages_are_informative(self): """Test that error messages are informative""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") # Test empty string error with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params("") + encrypter.decrypt_params("") assert "encrypted_data cannot be empty" in str(exc_info.value) # Test non-string error with pytest.raises(ValueError) as exc_info: - encrypter.decrypt_oauth_params(123) + encrypter.decrypt_params(123) assert "encrypted_data must be a string" in str(exc_info.value) # Test invalid format error short_data = base64.b64encode(b"short").decode() - with pytest.raises(OAuthEncryptionError) as exc_info: - encrypter.decrypt_oauth_params(short_data) + with pytest.raises(EncryptionError) as exc_info: + encrypter.decrypt_params(short_data) assert "Invalid encrypted data format" in str(exc_info.value) @@ -517,25 +517,25 @@ class TestEdgeCases: def test_very_long_secret_key(self): """Test with very long secret key""" long_secret = "x" * 10000 - encrypter = SystemOAuthEncrypter(long_secret) + encrypter = SystemEncrypter(long_secret) # Key should still be 32 bytes due to SHA-256 assert len(encrypter.key) == 32 # Should still work normally oauth_params = {"client_id": "test_id"} - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(oauth_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == oauth_params def test_special_characters_in_secret_key(self): """Test with special characters in secret key""" special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~testπŸš€" - encrypter = SystemOAuthEncrypter(special_secret) + encrypter = SystemEncrypter(special_secret) oauth_params = {"client_id": "test_id"} - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(oauth_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == oauth_params def test_empty_values_in_oauth_params(self): @@ -551,18 +551,18 @@ class TestEdgeCases: "none": None, } - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(oauth_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == oauth_params def test_deeply_nested_oauth_params(self): """Test with deeply nested oauth params""" oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}} - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(oauth_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == oauth_params def test_oauth_params_with_all_json_types(self): @@ -579,9 +579,9 @@ class TestEdgeCases: "object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True}, } - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(oauth_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == oauth_params @@ -593,27 +593,27 @@ class TestPerformance: large_value = "x" * 100000 # 100KB oauth_params = {"client_id": "test_id", "large_data": large_value} - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(oauth_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == oauth_params def test_many_fields_oauth_params(self): """Test with many fields in oauth params""" oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)} - encrypter = SystemOAuthEncrypter("test_secret") - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypter = SystemEncrypter("test_secret") + encrypted = encrypter.encrypt_params(oauth_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == oauth_params def test_repeated_encryption_decryption(self): """Test repeated encryption and decryption operations""" - encrypter = SystemOAuthEncrypter("test_secret") + encrypter = SystemEncrypter("test_secret") oauth_params = {"client_id": "test_id", "client_secret": "test_secret"} # Test multiple rounds of encryption/decryption for i in range(100): - encrypted = encrypter.encrypt_oauth_params(oauth_params) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(oauth_params) + decrypted = encrypter.decrypt_params(encrypted) assert decrypted == oauth_params