refactor(encryption): using oauth encryption as a general encryption util.

This commit is contained in:
Harry 2026-01-09 16:50:32 +08:00
parent 07ff8df58d
commit 925825a41b
4 changed files with 181 additions and 181 deletions

View File

@ -14,23 +14,23 @@ from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
class EncryptionError(Exception):
"""Encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
class SystemEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
A simple parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
This class provides methods to encrypt and decrypt parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: str | None = None):
"""
Initialize the OAuth encrypter.
Initialize the encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
def encrypt_params(self, params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Encrypt parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
params: parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
EncryptionError: If encryption fails
ValueError: If params is invalid
"""
try:
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
raise EncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Decrypt parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
EncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
if not isinstance(params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
return params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
raise EncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
"""
Create an OAuth encrypter instance.
Create an encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
return SystemEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: SystemOAuthEncrypter | None = None
_encrypter: SystemEncrypter | None = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
def get_system_encrypter() -> SystemEncrypter:
"""
Get the global OAuth encrypter instance.
Get the global encrypter instance.
Returns:
SystemOAuthEncrypter instance
SystemEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
global _encrypter
if _encrypter is None:
_encrypter = SystemEncrypter()
return _encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
def encrypt_system_params(params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Encrypt parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
params: parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
return get_system_encrypter().encrypt_params(params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Decrypt parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Decrypted parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
return get_system_encrypter().decrypt_params(encrypted_data)

View File

@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID
@ -502,7 +502,7 @@ class BuiltinToolManageService:
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.tools.utils.system_encryption import decrypt_system_params
from core.trigger.entities.api_entities import (
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
@ -591,7 +591,7 @@ class TriggerProviderService:
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")

View File

@ -7,13 +7,13 @@ from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad
from core.tools.utils.system_oauth_encryption import (
OAuthEncryptionError,
SystemOAuthEncrypter,
create_system_oauth_encrypter,
decrypt_system_oauth_params,
encrypt_system_oauth_params,
get_system_oauth_encrypter,
from core.tools.utils.system_encryption import (
EncryptionError,
SystemEncrypter,
create_system_encrypter,
decrypt_system_params,
encrypt_system_params,
get_system_encrypter,
)
@ -23,7 +23,7 @@ class TestSystemOAuthEncrypter:
def test_init_with_secret_key(self):
"""Test initialization with provided secret key"""
secret_key = "test_secret_key"
encrypter = SystemOAuthEncrypter(secret_key=secret_key)
encrypter = SystemEncrypter(secret_key=secret_key)
expected_key = hashlib.sha256(secret_key.encode()).digest()
assert encrypter.key == expected_key
@ -31,13 +31,13 @@ class TestSystemOAuthEncrypter:
"""Test initialization with None secret key falls back to config"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = SystemOAuthEncrypter(secret_key=None)
encrypter = SystemEncrypter(secret_key=None)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
def test_init_with_empty_secret_key(self):
"""Test initialization with empty secret key"""
encrypter = SystemOAuthEncrypter(secret_key="")
encrypter = SystemEncrypter(secret_key="")
expected_key = hashlib.sha256(b"").digest()
assert encrypter.key == expected_key
@ -45,16 +45,16 @@ class TestSystemOAuthEncrypter:
"""Test initialization without secret key uses config"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "default_secret"
encrypter = SystemOAuthEncrypter()
encrypter = SystemEncrypter()
expected_key = hashlib.sha256(b"default_secret").digest()
assert encrypter.key == expected_key
def test_encrypt_oauth_params_basic(self):
"""Test basic OAuth parameters encryption"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
encrypted = encrypter.encrypt_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
@ -66,16 +66,16 @@ class TestSystemOAuthEncrypter:
def test_encrypt_oauth_params_empty_dict(self):
"""Test encryption with empty dictionary"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
oauth_params = {}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
encrypted = encrypter.encrypt_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_complex_data(self):
"""Test encryption with complex data structures"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
oauth_params = {
"client_id": "test_id",
"client_secret": "test_secret",
@ -86,64 +86,64 @@ class TestSystemOAuthEncrypter:
"null_value": None,
}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
encrypted = encrypter.encrypt_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_unicode_data(self):
"""Test encryption with unicode data"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
encrypted = encrypter.encrypt_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_large_data(self):
"""Test encryption with large data"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
oauth_params = {
"client_id": "test_id",
"large_data": "x" * 10000, # 10KB of data
}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
encrypted = encrypter.encrypt_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_invalid_input(self):
"""Test encryption with invalid input types"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_oauth_params(None)
encrypter.encrypt_params(None)
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_oauth_params("not_a_dict")
encrypter.encrypt_params("not_a_dict")
def test_decrypt_oauth_params_basic(self):
"""Test basic OAuth parameters decryption"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_empty_dict(self):
"""Test decryption of empty dictionary"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
original_params = {}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_complex_data(self):
"""Test decryption with complex data structures"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"client_secret": "test_secret",
@ -154,104 +154,104 @@ class TestSystemOAuthEncrypter:
"null_value": None,
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_unicode_data(self):
"""Test decryption with unicode data"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"description": "This is a test case 🚀",
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_large_data(self):
"""Test decryption with large data"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"large_data": "x" * 10000, # 10KB of data
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_invalid_base64(self):
"""Test decryption with invalid base64 data"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params("invalid_base64!")
with pytest.raises(EncryptionError):
encrypter.decrypt_params("invalid_base64!")
def test_decrypt_oauth_params_empty_string(self):
"""Test decryption with empty string"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params("")
encrypter.decrypt_params("")
assert "encrypted_data cannot be empty" in str(exc_info.value)
def test_decrypt_oauth_params_non_string_input(self):
"""Test decryption with non-string input"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(123)
encrypter.decrypt_params(123)
assert "encrypted_data must be a string" in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(None)
encrypter.decrypt_params(None)
assert "encrypted_data must be a string" in str(exc_info.value)
def test_decrypt_oauth_params_too_short_data(self):
"""Test decryption with too short encrypted data"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
# Create data that's too short (less than 32 bytes)
short_data = base64.b64encode(b"short").decode()
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.decrypt_oauth_params(short_data)
with pytest.raises(EncryptionError) as exc_info:
encrypter.decrypt_params(short_data)
assert "Invalid encrypted data format" in str(exc_info.value)
def test_decrypt_oauth_params_corrupted_data(self):
"""Test decryption with corrupted data"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
# Create corrupted data (valid base64 but invalid encrypted content)
corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params(corrupted_data)
with pytest.raises(EncryptionError):
encrypter.decrypt_params(corrupted_data)
def test_decrypt_oauth_params_wrong_key(self):
"""Test decryption with wrong key"""
encrypter1 = SystemOAuthEncrypter("secret1")
encrypter2 = SystemOAuthEncrypter("secret2")
encrypter1 = SystemEncrypter("secret1")
encrypter2 = SystemEncrypter("secret2")
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter1.encrypt_oauth_params(original_params)
encrypted = encrypter1.encrypt_params(original_params)
with pytest.raises(OAuthEncryptionError):
encrypter2.decrypt_oauth_params(encrypted)
with pytest.raises(EncryptionError):
encrypter2.decrypt_params(encrypted)
def test_encryption_decryption_consistency(self):
"""Test that encryption and decryption are consistent"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
test_cases = [
{},
@ -264,42 +264,42 @@ class TestSystemOAuthEncrypter:
]
for original_params in test_cases:
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypted = encrypter.encrypt_params(original_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == original_params, f"Failed for case: {original_params}"
def test_encryption_randomness(self):
"""Test that encryption produces different results for same input"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted1 = encrypter.encrypt_oauth_params(oauth_params)
encrypted2 = encrypter.encrypt_oauth_params(oauth_params)
encrypted1 = encrypter.encrypt_params(oauth_params)
encrypted2 = encrypter.encrypt_params(oauth_params)
# Should be different due to random IV
assert encrypted1 != encrypted2
# But should decrypt to same result
decrypted1 = encrypter.decrypt_oauth_params(encrypted1)
decrypted2 = encrypter.decrypt_oauth_params(encrypted2)
decrypted1 = encrypter.decrypt_params(encrypted1)
decrypted2 = encrypter.decrypt_params(encrypted2)
assert decrypted1 == decrypted2 == oauth_params
def test_different_secret_keys_produce_different_results(self):
"""Test that different secret keys produce different encrypted results"""
encrypter1 = SystemOAuthEncrypter("secret1")
encrypter2 = SystemOAuthEncrypter("secret2")
encrypter1 = SystemEncrypter("secret1")
encrypter2 = SystemEncrypter("secret2")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted1 = encrypter1.encrypt_oauth_params(oauth_params)
encrypted2 = encrypter2.encrypt_oauth_params(oauth_params)
encrypted1 = encrypter1.encrypt_params(oauth_params)
encrypted2 = encrypter2.encrypt_params(oauth_params)
# Should produce different encrypted results
assert encrypted1 != encrypted2
# But each should decrypt correctly with its own key
decrypted1 = encrypter1.decrypt_oauth_params(encrypted1)
decrypted2 = encrypter2.decrypt_oauth_params(encrypted2)
decrypted1 = encrypter1.decrypt_params(encrypted1)
decrypted2 = encrypter2.decrypt_params(encrypted2)
assert decrypted1 == decrypted2 == oauth_params
@patch("core.tools.utils.system_oauth_encryption.get_random_bytes")
@ -307,11 +307,11 @@ class TestSystemOAuthEncrypter:
"""Test encryption when crypto operation fails"""
mock_get_random_bytes.side_effect = Exception("Crypto error")
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
oauth_params = {"client_id": "test_id"}
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.encrypt_oauth_params(oauth_params)
with pytest.raises(EncryptionError) as exc_info:
encrypter.encrypt_params(oauth_params)
assert "Encryption failed" in str(exc_info.value)
@ -320,17 +320,17 @@ class TestSystemOAuthEncrypter:
"""Test encryption when JSON serialization fails"""
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
oauth_params = {"client_id": "test_id"}
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.encrypt_oauth_params(oauth_params)
with pytest.raises(EncryptionError) as exc_info:
encrypter.encrypt_params(oauth_params)
assert "Encryption failed" in str(exc_info.value)
def test_decrypt_oauth_params_invalid_json(self):
"""Test decryption with invalid JSON data"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
# Create valid encrypted data but with invalid JSON content
iv = get_random_bytes(16)
@ -341,14 +341,14 @@ class TestSystemOAuthEncrypter:
combined = iv + encrypted_data
encoded = base64.b64encode(combined).decode()
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params(encoded)
with pytest.raises(EncryptionError):
encrypter.decrypt_params(encoded)
def test_key_derivation_consistency(self):
"""Test that key derivation is consistent"""
secret_key = "test_secret"
encrypter1 = SystemOAuthEncrypter(secret_key)
encrypter2 = SystemOAuthEncrypter(secret_key)
encrypter1 = SystemEncrypter(secret_key)
encrypter2 = SystemEncrypter(secret_key)
assert encrypter1.key == encrypter2.key
@ -362,9 +362,9 @@ class TestFactoryFunctions:
def test_create_system_oauth_encrypter_with_secret(self):
"""Test factory function with secret key"""
secret_key = "test_secret"
encrypter = create_system_oauth_encrypter(secret_key)
encrypter = create_system_encrypter(secret_key)
assert isinstance(encrypter, SystemOAuthEncrypter)
assert isinstance(encrypter, SystemEncrypter)
expected_key = hashlib.sha256(secret_key.encode()).digest()
assert encrypter.key == expected_key
@ -372,9 +372,9 @@ class TestFactoryFunctions:
"""Test factory function without secret key"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = create_system_oauth_encrypter()
encrypter = create_system_encrypter()
assert isinstance(encrypter, SystemOAuthEncrypter)
assert isinstance(encrypter, SystemEncrypter)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
@ -382,9 +382,9 @@ class TestFactoryFunctions:
"""Test factory function with None secret key"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = create_system_oauth_encrypter(None)
encrypter = create_system_encrypter(None)
assert isinstance(encrypter, SystemOAuthEncrypter)
assert isinstance(encrypter, SystemEncrypter)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
@ -395,26 +395,26 @@ class TestGlobalEncrypterInstance:
def test_get_system_oauth_encrypter_singleton(self):
"""Test that get_system_oauth_encrypter returns singleton instance"""
# Clear the global instance first
import core.tools.utils.system_oauth_encryption
import core.tools.utils.system_encryption
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
core.tools.utils.system_encryption._encrypter = None
encrypter1 = get_system_oauth_encrypter()
encrypter2 = get_system_oauth_encrypter()
encrypter1 = get_system_encrypter()
encrypter2 = get_system_encrypter()
assert encrypter1 is encrypter2
assert isinstance(encrypter1, SystemOAuthEncrypter)
assert isinstance(encrypter1, SystemEncrypter)
def test_get_system_oauth_encrypter_uses_config(self):
"""Test that global encrypter uses config"""
# Clear the global instance first
import core.tools.utils.system_oauth_encryption
import core.tools.utils.system_encryption
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
core.tools.utils.system_encryption._encrypter = None
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "global_secret"
encrypter = get_system_oauth_encrypter()
encrypter = get_system_encrypter()
expected_key = hashlib.sha256(b"global_secret").digest()
assert encrypter.key == expected_key
@ -427,7 +427,7 @@ class TestConvenienceFunctions:
"""Test encrypt_system_oauth_params convenience function"""
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypt_system_oauth_params(oauth_params)
encrypted = encrypt_system_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
@ -436,8 +436,8 @@ class TestConvenienceFunctions:
"""Test decrypt_system_oauth_params convenience function"""
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypt_system_oauth_params(oauth_params)
decrypted = decrypt_system_oauth_params(encrypted)
encrypted = encrypt_system_params(oauth_params)
decrypted = decrypt_system_params(encrypted)
assert decrypted == oauth_params
@ -453,22 +453,22 @@ class TestConvenienceFunctions:
]
for original_params in test_cases:
encrypted = encrypt_system_oauth_params(original_params)
decrypted = decrypt_system_oauth_params(encrypted)
encrypted = encrypt_system_params(original_params)
decrypted = decrypt_system_params(encrypted)
assert decrypted == original_params, f"Failed for case: {original_params}"
def test_convenience_functions_with_errors(self):
"""Test convenience functions with error conditions"""
# Test encryption with invalid input
with pytest.raises(Exception): # noqa: B017
encrypt_system_oauth_params(None)
encrypt_system_params(None)
# Test decryption with invalid input
with pytest.raises(ValueError):
decrypt_system_oauth_params("")
decrypt_system_params("")
with pytest.raises(ValueError):
decrypt_system_oauth_params(None)
decrypt_system_params(None)
class TestErrorHandling:
@ -476,14 +476,14 @@ class TestErrorHandling:
def test_oauth_encryption_error_inheritance(self):
"""Test that OAuthEncryptionError is a proper exception"""
error = OAuthEncryptionError("Test error")
error = EncryptionError("Test error")
assert isinstance(error, Exception)
assert str(error) == "Test error"
def test_oauth_encryption_error_with_cause(self):
"""Test OAuthEncryptionError with cause"""
original_error = ValueError("Original error")
error = OAuthEncryptionError("Wrapper error")
error = EncryptionError("Wrapper error")
error.__cause__ = original_error
assert isinstance(error, Exception)
@ -492,22 +492,22 @@ class TestErrorHandling:
def test_error_messages_are_informative(self):
"""Test that error messages are informative"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
# Test empty string error
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params("")
encrypter.decrypt_params("")
assert "encrypted_data cannot be empty" in str(exc_info.value)
# Test non-string error
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(123)
encrypter.decrypt_params(123)
assert "encrypted_data must be a string" in str(exc_info.value)
# Test invalid format error
short_data = base64.b64encode(b"short").decode()
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.decrypt_oauth_params(short_data)
with pytest.raises(EncryptionError) as exc_info:
encrypter.decrypt_params(short_data)
assert "Invalid encrypted data format" in str(exc_info.value)
@ -517,25 +517,25 @@ class TestEdgeCases:
def test_very_long_secret_key(self):
"""Test with very long secret key"""
long_secret = "x" * 10000
encrypter = SystemOAuthEncrypter(long_secret)
encrypter = SystemEncrypter(long_secret)
# Key should still be 32 bytes due to SHA-256
assert len(encrypter.key) == 32
# Should still work normally
oauth_params = {"client_id": "test_id"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypted = encrypter.encrypt_params(oauth_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == oauth_params
def test_special_characters_in_secret_key(self):
"""Test with special characters in secret key"""
special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
encrypter = SystemOAuthEncrypter(special_secret)
encrypter = SystemEncrypter(special_secret)
oauth_params = {"client_id": "test_id"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypted = encrypter.encrypt_params(oauth_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == oauth_params
def test_empty_values_in_oauth_params(self):
@ -551,18 +551,18 @@ class TestEdgeCases:
"none": None,
}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypter = SystemEncrypter("test_secret")
encrypted = encrypter.encrypt_params(oauth_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == oauth_params
def test_deeply_nested_oauth_params(self):
"""Test with deeply nested oauth params"""
oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypter = SystemEncrypter("test_secret")
encrypted = encrypter.encrypt_params(oauth_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == oauth_params
def test_oauth_params_with_all_json_types(self):
@ -579,9 +579,9 @@ class TestEdgeCases:
"object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypter = SystemEncrypter("test_secret")
encrypted = encrypter.encrypt_params(oauth_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == oauth_params
@ -593,27 +593,27 @@ class TestPerformance:
large_value = "x" * 100000 # 100KB
oauth_params = {"client_id": "test_id", "large_data": large_value}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypter = SystemEncrypter("test_secret")
encrypted = encrypter.encrypt_params(oauth_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == oauth_params
def test_many_fields_oauth_params(self):
"""Test with many fields in oauth params"""
oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypter = SystemEncrypter("test_secret")
encrypted = encrypter.encrypt_params(oauth_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == oauth_params
def test_repeated_encryption_decryption(self):
"""Test repeated encryption and decryption operations"""
encrypter = SystemOAuthEncrypter("test_secret")
encrypter = SystemEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
# Test multiple rounds of encryption/decryption
for i in range(100):
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
encrypted = encrypter.encrypt_params(oauth_params)
decrypted = encrypter.decrypt_params(encrypted)
assert decrypted == oauth_params