mirror of
https://github.com/langgenius/dify.git
synced 2026-01-14 06:07:33 +08:00
refactor(encryption): using oauth encryption as a general encryption util.
This commit is contained in:
parent
07ff8df58d
commit
925825a41b
@ -14,23 +14,23 @@ from configs import dify_config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEncryptionError(Exception):
|
||||
"""OAuth encryption/decryption specific error"""
|
||||
class EncryptionError(Exception):
|
||||
"""Encryption/decryption specific error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SystemOAuthEncrypter:
|
||||
class SystemEncrypter:
|
||||
"""
|
||||
A simple OAuth parameters encrypter using AES-CBC encryption.
|
||||
A simple parameters encrypter using AES-CBC encryption.
|
||||
|
||||
This class provides methods to encrypt and decrypt OAuth parameters
|
||||
This class provides methods to encrypt and decrypt parameters
|
||||
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||
"""
|
||||
|
||||
def __init__(self, secret_key: str | None = None):
|
||||
"""
|
||||
Initialize the OAuth encrypter.
|
||||
Initialize the encrypter.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
|
||||
# Generate a fixed 256-bit key using SHA-256
|
||||
self.key = hashlib.sha256(secret_key.encode()).digest()
|
||||
|
||||
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
|
||||
def encrypt_params(self, params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters.
|
||||
Encrypt parameters.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
params: parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If encryption fails
|
||||
ValueError: If oauth_params is invalid
|
||||
EncryptionError: If encryption fails
|
||||
ValueError: If params is invalid
|
||||
"""
|
||||
|
||||
try:
|
||||
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Encrypt data
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
|
||||
# Combine IV and encrypted data
|
||||
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
|
||||
return base64.b64encode(combined).decode()
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
raise EncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
|
||||
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters.
|
||||
Decrypt parameters.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
Decrypted parameters dictionary
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If decryption fails
|
||||
EncryptionError: If decryption fails
|
||||
ValueError: If encrypted_data is invalid
|
||||
"""
|
||||
if not isinstance(encrypted_data, str):
|
||||
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
|
||||
unpadded_data = unpad(decrypted_data, AES.block_size)
|
||||
|
||||
# Parse JSON
|
||||
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
|
||||
if not isinstance(oauth_params, dict):
|
||||
if not isinstance(params, dict):
|
||||
raise ValueError("Decrypted data is not a valid dictionary")
|
||||
|
||||
return oauth_params
|
||||
return params
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
raise EncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
|
||||
|
||||
# Factory function for creating encrypter instances
|
||||
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
|
||||
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
|
||||
"""
|
||||
Create an OAuth encrypter instance.
|
||||
Create an encrypter instance.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
return SystemOAuthEncrypter(secret_key=secret_key)
|
||||
return SystemEncrypter(secret_key=secret_key)
|
||||
|
||||
|
||||
# Global encrypter instance (for backward compatibility)
|
||||
_oauth_encrypter: SystemOAuthEncrypter | None = None
|
||||
_encrypter: SystemEncrypter | None = None
|
||||
|
||||
|
||||
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
||||
def get_system_encrypter() -> SystemEncrypter:
|
||||
"""
|
||||
Get the global OAuth encrypter instance.
|
||||
Get the global encrypter instance.
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
global _oauth_encrypter
|
||||
if _oauth_encrypter is None:
|
||||
_oauth_encrypter = SystemOAuthEncrypter()
|
||||
return _oauth_encrypter
|
||||
global _encrypter
|
||||
if _encrypter is None:
|
||||
_encrypter = SystemEncrypter()
|
||||
return _encrypter
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
|
||||
def encrypt_system_params(params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters using the global encrypter.
|
||||
Encrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary
|
||||
params: parameters dictionary
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
"""
|
||||
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
|
||||
return get_system_encrypter().encrypt_params(params)
|
||||
|
||||
|
||||
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters using the global encrypter.
|
||||
Decrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
Decrypted parameters dictionary
|
||||
"""
|
||||
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
|
||||
return get_system_encrypter().decrypt_params(encrypted_data)
|
||||
@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.tools.utils.system_encryption import decrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import ToolProviderID
|
||||
@ -502,7 +502,7 @@ class BuiltinToolManageService:
|
||||
)
|
||||
if system_client:
|
||||
try:
|
||||
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
|
||||
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error decrypting system oauth params: {e}")
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.tools.utils.system_encryption import decrypt_system_params
|
||||
from core.trigger.entities.api_entities import (
|
||||
TriggerProviderApiEntity,
|
||||
TriggerProviderSubscriptionApiEntity,
|
||||
@ -591,7 +591,7 @@ class TriggerProviderService:
|
||||
|
||||
if system_client:
|
||||
try:
|
||||
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
|
||||
oauth_params = decrypt_system_params(system_client.encrypted_oauth_params)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error decrypting system oauth params: {e}")
|
||||
|
||||
|
||||
@ -7,13 +7,13 @@ from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
from Crypto.Util.Padding import pad
|
||||
|
||||
from core.tools.utils.system_oauth_encryption import (
|
||||
OAuthEncryptionError,
|
||||
SystemOAuthEncrypter,
|
||||
create_system_oauth_encrypter,
|
||||
decrypt_system_oauth_params,
|
||||
encrypt_system_oauth_params,
|
||||
get_system_oauth_encrypter,
|
||||
from core.tools.utils.system_encryption import (
|
||||
EncryptionError,
|
||||
SystemEncrypter,
|
||||
create_system_encrypter,
|
||||
decrypt_system_params,
|
||||
encrypt_system_params,
|
||||
get_system_encrypter,
|
||||
)
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ class TestSystemOAuthEncrypter:
|
||||
def test_init_with_secret_key(self):
|
||||
"""Test initialization with provided secret key"""
|
||||
secret_key = "test_secret_key"
|
||||
encrypter = SystemOAuthEncrypter(secret_key=secret_key)
|
||||
encrypter = SystemEncrypter(secret_key=secret_key)
|
||||
expected_key = hashlib.sha256(secret_key.encode()).digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
@ -31,13 +31,13 @@ class TestSystemOAuthEncrypter:
|
||||
"""Test initialization with None secret key falls back to config"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = SystemOAuthEncrypter(secret_key=None)
|
||||
encrypter = SystemEncrypter(secret_key=None)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_init_with_empty_secret_key(self):
|
||||
"""Test initialization with empty secret key"""
|
||||
encrypter = SystemOAuthEncrypter(secret_key="")
|
||||
encrypter = SystemEncrypter(secret_key="")
|
||||
expected_key = hashlib.sha256(b"").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
@ -45,16 +45,16 @@ class TestSystemOAuthEncrypter:
|
||||
"""Test initialization without secret key uses config"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "default_secret"
|
||||
encrypter = SystemOAuthEncrypter()
|
||||
encrypter = SystemEncrypter()
|
||||
expected_key = hashlib.sha256(b"default_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_encrypt_oauth_params_basic(self):
|
||||
"""Test basic OAuth parameters encryption"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
@ -66,16 +66,16 @@ class TestSystemOAuthEncrypter:
|
||||
|
||||
def test_encrypt_oauth_params_empty_dict(self):
|
||||
"""Test encryption with empty dictionary"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
oauth_params = {}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_complex_data(self):
|
||||
"""Test encryption with complex data structures"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
oauth_params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
@ -86,64 +86,64 @@ class TestSystemOAuthEncrypter:
|
||||
"null_value": None,
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_unicode_data(self):
|
||||
"""Test encryption with unicode data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_large_data(self):
|
||||
"""Test encryption with large data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
oauth_params = {
|
||||
"client_id": "test_id",
|
||||
"large_data": "x" * 10000, # 10KB of data
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_invalid_input(self):
|
||||
"""Test encryption with invalid input types"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypter.encrypt_oauth_params(None)
|
||||
encrypter.encrypt_params(None)
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypter.encrypt_oauth_params("not_a_dict")
|
||||
encrypter.encrypt_params("not_a_dict")
|
||||
|
||||
def test_decrypt_oauth_params_basic(self):
|
||||
"""Test basic OAuth parameters decryption"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_empty_dict(self):
|
||||
"""Test decryption of empty dictionary"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
original_params = {}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_complex_data(self):
|
||||
"""Test decryption with complex data structures"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
@ -154,104 +154,104 @@ class TestSystemOAuthEncrypter:
|
||||
"null_value": None,
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_unicode_data(self):
|
||||
"""Test decryption with unicode data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"description": "This is a test case 🚀",
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_large_data(self):
|
||||
"""Test decryption with large data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"large_data": "x" * 10000, # 10KB of data
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_invalid_base64(self):
|
||||
"""Test decryption with invalid base64 data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter.decrypt_oauth_params("invalid_base64!")
|
||||
with pytest.raises(EncryptionError):
|
||||
encrypter.decrypt_params("invalid_base64!")
|
||||
|
||||
def test_decrypt_oauth_params_empty_string(self):
|
||||
"""Test decryption with empty string"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params("")
|
||||
encrypter.decrypt_params("")
|
||||
|
||||
assert "encrypted_data cannot be empty" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_non_string_input(self):
|
||||
"""Test decryption with non-string input"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(123)
|
||||
encrypter.decrypt_params(123)
|
||||
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(None)
|
||||
encrypter.decrypt_params(None)
|
||||
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_too_short_data(self):
|
||||
"""Test decryption with too short encrypted data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
# Create data that's too short (less than 32 bytes)
|
||||
short_data = base64.b64encode(b"short").decode()
|
||||
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(short_data)
|
||||
with pytest.raises(EncryptionError) as exc_info:
|
||||
encrypter.decrypt_params(short_data)
|
||||
|
||||
assert "Invalid encrypted data format" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_corrupted_data(self):
|
||||
"""Test decryption with corrupted data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
# Create corrupted data (valid base64 but invalid encrypted content)
|
||||
corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter.decrypt_oauth_params(corrupted_data)
|
||||
with pytest.raises(EncryptionError):
|
||||
encrypter.decrypt_params(corrupted_data)
|
||||
|
||||
def test_decrypt_oauth_params_wrong_key(self):
|
||||
"""Test decryption with wrong key"""
|
||||
encrypter1 = SystemOAuthEncrypter("secret1")
|
||||
encrypter2 = SystemOAuthEncrypter("secret2")
|
||||
encrypter1 = SystemEncrypter("secret1")
|
||||
encrypter2 = SystemEncrypter("secret2")
|
||||
|
||||
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
encrypted = encrypter1.encrypt_oauth_params(original_params)
|
||||
encrypted = encrypter1.encrypt_params(original_params)
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter2.decrypt_oauth_params(encrypted)
|
||||
with pytest.raises(EncryptionError):
|
||||
encrypter2.decrypt_params(encrypted)
|
||||
|
||||
def test_encryption_decryption_consistency(self):
|
||||
"""Test that encryption and decryption are consistent"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
test_cases = [
|
||||
{},
|
||||
@ -264,42 +264,42 @@ class TestSystemOAuthEncrypter:
|
||||
]
|
||||
|
||||
for original_params in test_cases:
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypted = encrypter.encrypt_params(original_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == original_params, f"Failed for case: {original_params}"
|
||||
|
||||
def test_encryption_randomness(self):
|
||||
"""Test that encryption produces different results for same input"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted1 = encrypter.encrypt_oauth_params(oauth_params)
|
||||
encrypted2 = encrypter.encrypt_oauth_params(oauth_params)
|
||||
encrypted1 = encrypter.encrypt_params(oauth_params)
|
||||
encrypted2 = encrypter.encrypt_params(oauth_params)
|
||||
|
||||
# Should be different due to random IV
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
# But should decrypt to same result
|
||||
decrypted1 = encrypter.decrypt_oauth_params(encrypted1)
|
||||
decrypted2 = encrypter.decrypt_oauth_params(encrypted2)
|
||||
decrypted1 = encrypter.decrypt_params(encrypted1)
|
||||
decrypted2 = encrypter.decrypt_params(encrypted2)
|
||||
assert decrypted1 == decrypted2 == oauth_params
|
||||
|
||||
def test_different_secret_keys_produce_different_results(self):
|
||||
"""Test that different secret keys produce different encrypted results"""
|
||||
encrypter1 = SystemOAuthEncrypter("secret1")
|
||||
encrypter2 = SystemOAuthEncrypter("secret2")
|
||||
encrypter1 = SystemEncrypter("secret1")
|
||||
encrypter2 = SystemEncrypter("secret2")
|
||||
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted1 = encrypter1.encrypt_oauth_params(oauth_params)
|
||||
encrypted2 = encrypter2.encrypt_oauth_params(oauth_params)
|
||||
encrypted1 = encrypter1.encrypt_params(oauth_params)
|
||||
encrypted2 = encrypter2.encrypt_params(oauth_params)
|
||||
|
||||
# Should produce different encrypted results
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
# But each should decrypt correctly with its own key
|
||||
decrypted1 = encrypter1.decrypt_oauth_params(encrypted1)
|
||||
decrypted2 = encrypter2.decrypt_oauth_params(encrypted2)
|
||||
decrypted1 = encrypter1.decrypt_params(encrypted1)
|
||||
decrypted2 = encrypter2.decrypt_params(encrypted2)
|
||||
assert decrypted1 == decrypted2 == oauth_params
|
||||
|
||||
@patch("core.tools.utils.system_oauth_encryption.get_random_bytes")
|
||||
@ -307,11 +307,11 @@ class TestSystemOAuthEncrypter:
|
||||
"""Test encryption when crypto operation fails"""
|
||||
mock_get_random_bytes.side_effect = Exception("Crypto error")
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.encrypt_oauth_params(oauth_params)
|
||||
with pytest.raises(EncryptionError) as exc_info:
|
||||
encrypter.encrypt_params(oauth_params)
|
||||
|
||||
assert "Encryption failed" in str(exc_info.value)
|
||||
|
||||
@ -320,17 +320,17 @@ class TestSystemOAuthEncrypter:
|
||||
"""Test encryption when JSON serialization fails"""
|
||||
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.encrypt_oauth_params(oauth_params)
|
||||
with pytest.raises(EncryptionError) as exc_info:
|
||||
encrypter.encrypt_params(oauth_params)
|
||||
|
||||
assert "Encryption failed" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_invalid_json(self):
|
||||
"""Test decryption with invalid JSON data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
# Create valid encrypted data but with invalid JSON content
|
||||
iv = get_random_bytes(16)
|
||||
@ -341,14 +341,14 @@ class TestSystemOAuthEncrypter:
|
||||
combined = iv + encrypted_data
|
||||
encoded = base64.b64encode(combined).decode()
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter.decrypt_oauth_params(encoded)
|
||||
with pytest.raises(EncryptionError):
|
||||
encrypter.decrypt_params(encoded)
|
||||
|
||||
def test_key_derivation_consistency(self):
|
||||
"""Test that key derivation is consistent"""
|
||||
secret_key = "test_secret"
|
||||
encrypter1 = SystemOAuthEncrypter(secret_key)
|
||||
encrypter2 = SystemOAuthEncrypter(secret_key)
|
||||
encrypter1 = SystemEncrypter(secret_key)
|
||||
encrypter2 = SystemEncrypter(secret_key)
|
||||
|
||||
assert encrypter1.key == encrypter2.key
|
||||
|
||||
@ -362,9 +362,9 @@ class TestFactoryFunctions:
|
||||
def test_create_system_oauth_encrypter_with_secret(self):
|
||||
"""Test factory function with secret key"""
|
||||
secret_key = "test_secret"
|
||||
encrypter = create_system_oauth_encrypter(secret_key)
|
||||
encrypter = create_system_encrypter(secret_key)
|
||||
|
||||
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||
assert isinstance(encrypter, SystemEncrypter)
|
||||
expected_key = hashlib.sha256(secret_key.encode()).digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
@ -372,9 +372,9 @@ class TestFactoryFunctions:
|
||||
"""Test factory function without secret key"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = create_system_oauth_encrypter()
|
||||
encrypter = create_system_encrypter()
|
||||
|
||||
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||
assert isinstance(encrypter, SystemEncrypter)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
@ -382,9 +382,9 @@ class TestFactoryFunctions:
|
||||
"""Test factory function with None secret key"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = create_system_oauth_encrypter(None)
|
||||
encrypter = create_system_encrypter(None)
|
||||
|
||||
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||
assert isinstance(encrypter, SystemEncrypter)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
@ -395,26 +395,26 @@ class TestGlobalEncrypterInstance:
|
||||
def test_get_system_oauth_encrypter_singleton(self):
|
||||
"""Test that get_system_oauth_encrypter returns singleton instance"""
|
||||
# Clear the global instance first
|
||||
import core.tools.utils.system_oauth_encryption
|
||||
import core.tools.utils.system_encryption
|
||||
|
||||
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
|
||||
core.tools.utils.system_encryption._encrypter = None
|
||||
|
||||
encrypter1 = get_system_oauth_encrypter()
|
||||
encrypter2 = get_system_oauth_encrypter()
|
||||
encrypter1 = get_system_encrypter()
|
||||
encrypter2 = get_system_encrypter()
|
||||
|
||||
assert encrypter1 is encrypter2
|
||||
assert isinstance(encrypter1, SystemOAuthEncrypter)
|
||||
assert isinstance(encrypter1, SystemEncrypter)
|
||||
|
||||
def test_get_system_oauth_encrypter_uses_config(self):
|
||||
"""Test that global encrypter uses config"""
|
||||
# Clear the global instance first
|
||||
import core.tools.utils.system_oauth_encryption
|
||||
import core.tools.utils.system_encryption
|
||||
|
||||
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
|
||||
core.tools.utils.system_encryption._encrypter = None
|
||||
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "global_secret"
|
||||
encrypter = get_system_oauth_encrypter()
|
||||
encrypter = get_system_encrypter()
|
||||
|
||||
expected_key = hashlib.sha256(b"global_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
@ -427,7 +427,7 @@ class TestConvenienceFunctions:
|
||||
"""Test encrypt_system_oauth_params convenience function"""
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypt_system_oauth_params(oauth_params)
|
||||
encrypted = encrypt_system_params(oauth_params)
|
||||
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
@ -436,8 +436,8 @@ class TestConvenienceFunctions:
|
||||
"""Test decrypt_system_oauth_params convenience function"""
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypt_system_oauth_params(oauth_params)
|
||||
decrypted = decrypt_system_oauth_params(encrypted)
|
||||
encrypted = encrypt_system_params(oauth_params)
|
||||
decrypted = decrypt_system_params(encrypted)
|
||||
|
||||
assert decrypted == oauth_params
|
||||
|
||||
@ -453,22 +453,22 @@ class TestConvenienceFunctions:
|
||||
]
|
||||
|
||||
for original_params in test_cases:
|
||||
encrypted = encrypt_system_oauth_params(original_params)
|
||||
decrypted = decrypt_system_oauth_params(encrypted)
|
||||
encrypted = encrypt_system_params(original_params)
|
||||
decrypted = decrypt_system_params(encrypted)
|
||||
assert decrypted == original_params, f"Failed for case: {original_params}"
|
||||
|
||||
def test_convenience_functions_with_errors(self):
|
||||
"""Test convenience functions with error conditions"""
|
||||
# Test encryption with invalid input
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypt_system_oauth_params(None)
|
||||
encrypt_system_params(None)
|
||||
|
||||
# Test decryption with invalid input
|
||||
with pytest.raises(ValueError):
|
||||
decrypt_system_oauth_params("")
|
||||
decrypt_system_params("")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
decrypt_system_oauth_params(None)
|
||||
decrypt_system_params(None)
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
@ -476,14 +476,14 @@ class TestErrorHandling:
|
||||
|
||||
def test_oauth_encryption_error_inheritance(self):
|
||||
"""Test that OAuthEncryptionError is a proper exception"""
|
||||
error = OAuthEncryptionError("Test error")
|
||||
error = EncryptionError("Test error")
|
||||
assert isinstance(error, Exception)
|
||||
assert str(error) == "Test error"
|
||||
|
||||
def test_oauth_encryption_error_with_cause(self):
|
||||
"""Test OAuthEncryptionError with cause"""
|
||||
original_error = ValueError("Original error")
|
||||
error = OAuthEncryptionError("Wrapper error")
|
||||
error = EncryptionError("Wrapper error")
|
||||
error.__cause__ = original_error
|
||||
|
||||
assert isinstance(error, Exception)
|
||||
@ -492,22 +492,22 @@ class TestErrorHandling:
|
||||
|
||||
def test_error_messages_are_informative(self):
|
||||
"""Test that error messages are informative"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
|
||||
# Test empty string error
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params("")
|
||||
encrypter.decrypt_params("")
|
||||
assert "encrypted_data cannot be empty" in str(exc_info.value)
|
||||
|
||||
# Test non-string error
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(123)
|
||||
encrypter.decrypt_params(123)
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
# Test invalid format error
|
||||
short_data = base64.b64encode(b"short").decode()
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(short_data)
|
||||
with pytest.raises(EncryptionError) as exc_info:
|
||||
encrypter.decrypt_params(short_data)
|
||||
assert "Invalid encrypted data format" in str(exc_info.value)
|
||||
|
||||
|
||||
@ -517,25 +517,25 @@ class TestEdgeCases:
|
||||
def test_very_long_secret_key(self):
|
||||
"""Test with very long secret key"""
|
||||
long_secret = "x" * 10000
|
||||
encrypter = SystemOAuthEncrypter(long_secret)
|
||||
encrypter = SystemEncrypter(long_secret)
|
||||
|
||||
# Key should still be 32 bytes due to SHA-256
|
||||
assert len(encrypter.key) == 32
|
||||
|
||||
# Should still work normally
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_special_characters_in_secret_key(self):
|
||||
"""Test with special characters in secret key"""
|
||||
special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
|
||||
encrypter = SystemOAuthEncrypter(special_secret)
|
||||
encrypter = SystemEncrypter(special_secret)
|
||||
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_empty_values_in_oauth_params(self):
|
||||
@ -551,18 +551,18 @@ class TestEdgeCases:
|
||||
"none": None,
|
||||
}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_deeply_nested_oauth_params(self):
|
||||
"""Test with deeply nested oauth params"""
|
||||
oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_oauth_params_with_all_json_types(self):
|
||||
@ -579,9 +579,9 @@ class TestEdgeCases:
|
||||
"object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
|
||||
}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
|
||||
@ -593,27 +593,27 @@ class TestPerformance:
|
||||
large_value = "x" * 100000 # 100KB
|
||||
oauth_params = {"client_id": "test_id", "large_data": large_value}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_many_fields_oauth_params(self):
|
||||
"""Test with many fields in oauth params"""
|
||||
oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_repeated_encryption_decryption(self):
|
||||
"""Test repeated encryption and decryption operations"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypter = SystemEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
# Test multiple rounds of encryption/decryption
|
||||
for i in range(100):
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
encrypted = encrypter.encrypt_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
Loading…
Reference in New Issue
Block a user