mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-13 21:57:16 +08:00
354 lines
13 KiB
Python
354 lines
13 KiB
Python
import base64
|
|
import os
|
|
import tempfile
|
|
import time
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
from cryptography.hazmat.primitives import serialization
|
|
from cryptography.hazmat.primitives.asymmetric import ed25519
|
|
|
|
from ollama._auth import OllamaAuth
|
|
|
|
|
|
@pytest.fixture
def temp_key_pair():
    """Yield ``(key_path, private_key)`` for a fresh Ed25519 key stored on disk.

    The key is generated on the fly and written, unencrypted, as an
    OpenSSH-format private key (PEM armored) into a temporary file whose name
    ends in ``_ed25519``. On teardown the file is deleted if it still exists.
    """
    key = ed25519.Ed25519PrivateKey.generate()

    # OpenSSH private-key container, PEM armored, no passphrase.
    pem_bytes = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.OpenSSH,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # delete=False: the file must outlive the handle so the code under test can reopen it.
    handle = tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='_ed25519')
    with handle:
        handle.write(pem_bytes)
    key_file = handle.name

    yield key_file, key

    try:
        os.unlink(key_file)
    except FileNotFoundError:
        pass  # already gone (e.g. removed by the test); nothing to clean up
|
|
|
|
|
|
@pytest.fixture
def invalid_key_file():
    """Yield the path of a temporary file whose contents are not a private key.

    The file holds a short plain-text sentence instead of key material. It is
    intended to be rejected when parsed as a key. The file is deleted on
    teardown if it still exists.
    """
    # delete=False: the file must outlive the handle so the code under test can reopen it.
    handle = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_invalid')
    with handle:
        handle.write("This is not a valid private key")
    bogus_path = handle.name

    yield bogus_path

    try:
        os.unlink(bogus_path)
    except FileNotFoundError:
        pass  # already gone (e.g. removed by the test); nothing to clean up
|
|
|
|
|
|
class TestOllamaAuth:
    """Test suite for OllamaAuth class."""

    @staticmethod
    def _strict_b64decode(value, failure_message):
        """Decode *value* as strict base64, failing the test with *failure_message* otherwise.

        ``validate=True`` makes ``b64decode`` reject characters outside the
        base64 alphabet. Without it they are silently discarded, and almost any
        string would "decode". Only the decode is guarded. Assertions on the
        decoded bytes belong to the caller, so they are not mislabelled as
        base64 errors.
        """
        try:
            return base64.b64decode(value, validate=True)
        except ValueError:  # binascii.Error is a ValueError subclass
            pytest.fail(failure_message)

    def test_init_default_key_path(self):
        """Test initialization with default key path."""
        auth = OllamaAuth()
        expected_path = os.path.join(str(Path.home()), '.ollama', 'id_ed25519')
        assert auth.key_path == expected_path

    def test_init_custom_key_path(self):
        """Test initialization with custom key path."""
        custom_path = "/custom/path/to/key"
        auth = OllamaAuth(key_path=custom_path)
        assert auth.key_path == custom_path

    def test_init_expanduser_path(self):
        """Test initialization with path containing ~ expansion."""
        auth = OllamaAuth(key_path="~/custom/key")
        expected_path = os.path.expanduser("~/custom/key")
        assert auth.key_path == expected_path

    def test_init_expandvars_path(self):
        """Test initialization with path containing environment variables."""
        with patch.dict(os.environ, {'TEST_DIR': '/test/dir'}):
            auth = OllamaAuth(key_path="$TEST_DIR/key")
            assert auth.key_path == "/test/dir/key"

    def test_load_private_key_success(self, temp_key_pair):
        """Test successful private key loading."""
        temp_key_path, expected_private_key = temp_key_pair

        auth = OllamaAuth(key_path=temp_key_path)
        loaded_key = auth.load_private_key()

        # Verify the loaded key is an Ed25519 private key
        assert isinstance(loaded_key, ed25519.Ed25519PrivateKey)

        # Private keys are compared indirectly: identical raw public-key bytes
        # mean the loaded key is the one the fixture wrote.
        expected_bytes = expected_private_key.public_key().public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )
        loaded_bytes = loaded_key.public_key().public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )

        assert expected_bytes == loaded_bytes

    def test_load_private_key_file_not_found(self):
        """Test FileNotFoundError when key file doesn't exist."""
        auth = OllamaAuth(key_path="/nonexistent/path/key")

        with pytest.raises(FileNotFoundError, match="Could not find Ollama private key"):
            auth.load_private_key()

    def test_load_private_key_invalid_key(self, invalid_key_file):
        """Test ValueError when key file contains invalid data."""
        auth = OllamaAuth(key_path=invalid_key_file)

        with pytest.raises(ValueError, match="Invalid private key"):
            auth.load_private_key()

    def test_get_public_key_b64(self, temp_key_pair):
        """Test base64 public key extraction."""
        temp_key_path, private_key = temp_key_pair

        auth = OllamaAuth(key_path=temp_key_path)
        public_key_b64 = auth.get_public_key_b64(private_key)

        # Must be strict, non-empty base64.
        decoded = self._strict_b64decode(public_key_b64, "Returned value is not valid base64")
        assert len(decoded) > 0

        # It must equal the second (base64) field of the key's OpenSSH
        # public-key line.
        openssh_pub = private_key.public_key().public_bytes(
            encoding=serialization.Encoding.OpenSSH,
            format=serialization.PublicFormat.OpenSSH,
        ).decode('utf-8').strip()

        expected_b64 = openssh_pub.split(' ')[1]
        assert public_key_b64 == expected_b64

    def test_get_public_key_b64_malformed_key(self):
        """Test ValueError when OpenSSH public key is malformed."""
        auth = OllamaAuth()

        # Create a mock private key that produces malformed OpenSSH output
        mock_private_key = Mock()
        mock_public_key = Mock()
        mock_private_key.public_key.return_value = mock_public_key
        mock_public_key.public_bytes.return_value = b"malformed"

        with pytest.raises(ValueError, match="Malformed OpenSSH public key"):
            auth.get_public_key_b64(mock_private_key)

    def test_sign_request_basic(self, temp_key_pair):
        """Test basic request signing functionality."""
        temp_key_path, _ = temp_key_pair

        auth = OllamaAuth(key_path=temp_key_path)

        method = "POST"
        path = "/api/chat"

        with patch('time.time', return_value=1234567890):
            auth_token, timestamp = auth.sign_request(method, path)

        assert timestamp == "1234567890"
        assert isinstance(auth_token, str)
        assert ':' in auth_token  # Should contain public_key:signature format

        # Split and verify format
        parts = auth_token.split(':')
        assert len(parts) == 2

        public_key_b64, signature_b64 = parts

        # Both parts must be strict base64; the signature must be non-empty.
        self._strict_b64decode(public_key_b64, "Public key part is not valid base64")
        signature_bytes = self._strict_b64decode(signature_b64, "Signature part is not valid base64")
        assert len(signature_bytes) > 0

    def test_sign_request_with_query_params(self, temp_key_pair):
        """Test request signing with existing query parameters."""
        temp_key_path, _ = temp_key_pair

        auth = OllamaAuth(key_path=temp_key_path)

        method = "GET"
        path = "/api/models?format=json"

        with patch('time.time', return_value=1234567890):
            auth_token, timestamp = auth.sign_request(method, path)

        assert timestamp == "1234567890"
        assert isinstance(auth_token, str)

        # NOTE(review): the intended challenge is presumably
        # "GET,/api/models?format=json&ts=1234567890". This test only checks the
        # token shape, not the exact challenge; confirm against the implementation
        # before asserting it.
        parts = auth_token.split(':')
        assert len(parts) == 2

    def test_sign_request_different_methods(self, temp_key_pair):
        """Test request signing with different HTTP methods."""
        temp_key_path, _ = temp_key_pair

        auth = OllamaAuth(key_path=temp_key_path)

        methods = ["GET", "POST", "PUT", "DELETE"]
        path = "/api/test"

        # Freeze time so the method is the only input that varies.
        with patch('time.time', return_value=1234567890):
            signatures = {method: auth.sign_request(method, path)[0] for method in methods}

        # All signatures should be different (different challenges)
        assert len(set(signatures.values())) == len(methods)

    def test_sign_request_different_paths(self, temp_key_pair):
        """Test request signing with different paths."""
        temp_key_path, _ = temp_key_pair

        auth = OllamaAuth(key_path=temp_key_path)

        method = "POST"
        paths = ["/api/chat", "/api/generate", "/api/models"]

        # Freeze time so the path is the only input that varies.
        with patch('time.time', return_value=1234567890):
            signatures = {path: auth.sign_request(method, path)[0] for path in paths}

        # All signatures should be different (different challenges)
        assert len(set(signatures.values())) == len(paths)

    def test_sign_request_file_not_found(self):
        """Test request signing when key file doesn't exist."""
        auth = OllamaAuth(key_path="/nonexistent/path/key")

        with pytest.raises(FileNotFoundError, match="Could not find Ollama private key"):
            auth.sign_request("POST", "/api/chat")

    def test_sign_request_invalid_key(self, invalid_key_file):
        """Test request signing with invalid key file."""
        auth = OllamaAuth(key_path=invalid_key_file)

        with pytest.raises(ValueError, match="Invalid private key"):
            auth.sign_request("POST", "/api/chat")

    @patch('time.time')
    def test_sign_request_timestamp_generation(self, mock_time, temp_key_pair):
        """Test that fractional seconds are truncated in the returned timestamp."""
        temp_key_path, _ = temp_key_pair

        auth = OllamaAuth(key_path=temp_key_path)

        # Test with different timestamps
        mock_time.return_value = 1000.5
        _, timestamp1 = auth.sign_request("POST", "/api/chat")
        assert timestamp1 == "1000"

        mock_time.return_value = 2000.9
        _, timestamp2 = auth.sign_request("POST", "/api/chat")
        assert timestamp2 == "2000"

    def test_signature_verification_challenge_format(self, temp_key_pair):
        """Test that the challenge is formatted correctly for signature verification."""
        # Local import: only this test needs the exception type.
        from cryptography.exceptions import InvalidSignature

        temp_key_path, _ = temp_key_pair

        auth = OllamaAuth(key_path=temp_key_path)

        method = "POST"
        path = "/api/chat"

        with patch('time.time', return_value=1234567890):
            auth_token, timestamp = auth.sign_request(method, path)

        # Extract the signature part of "public_key:signature"
        _, signature_b64 = auth_token.split(':')
        signature = self._strict_b64decode(signature_b64, "Signature part is not valid base64")

        # The challenge should be "POST,/api/chat?ts=1234567890"
        expected_challenge = f"{method},{path}?ts={timestamp}"

        public_key = auth.load_private_key().public_key()

        # Only a genuine verification failure means the challenge format is wrong;
        # any other error propagates unchanged.
        try:
            public_key.verify(signature, expected_challenge.encode())
        except InvalidSignature:
            pytest.fail("Signature verification failed - challenge format incorrect")

    def test_integration_full_flow(self, temp_key_pair):
        """Test the complete authentication flow integration."""
        temp_key_path, _ = temp_key_pair

        # Test complete flow: init -> sign -> verify
        auth = OllamaAuth(key_path=temp_key_path)

        method = "POST"
        path = "/api/chat"

        # Sign a request (real clock, no patching)
        auth_token, timestamp = auth.sign_request(method, path)

        # Verify the components
        assert isinstance(auth_token, str)
        assert isinstance(timestamp, str)
        assert int(timestamp) > 0  # Should be a valid timestamp

        # Verify token format
        parts = auth_token.split(':')
        assert len(parts) == 2

        public_key_b64, signature_b64 = parts

        # Verify we can strictly decode both parts
        public_key_bytes = self._strict_b64decode(public_key_b64, "Public key part is not valid base64")
        signature_bytes = self._strict_b64decode(signature_b64, "Signature part is not valid base64")

        assert len(public_key_bytes) > 0
        assert len(signature_bytes) > 0

        # Verify the signature is valid for the challenge
        challenge = f"{method},{path}?ts={timestamp}"
        public_key = auth.load_private_key().public_key()

        # This should not raise an exception
        public_key.verify(signature_bytes, challenge.encode())
|