ollama-python/tests/test_auth.py
2025-09-18 13:57:57 -07:00

354 lines
13 KiB
Python

import base64
import os
import tempfile
import time
from pathlib import Path
from unittest.mock import Mock, patch
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
from ollama._auth import OllamaAuth
@pytest.fixture
def temp_key_pair():
"""Create a temporary Ed25519 key pair for testing."""
# Generate a test key pair
private_key = ed25519.Ed25519PrivateKey.generate()
# Serialize the private key in OpenSSH format
private_key_bytes = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=serialization.NoEncryption()
)
# Create temporary file
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='_ed25519') as f:
f.write(private_key_bytes)
temp_key_path = f.name
yield temp_key_path, private_key
# Cleanup
try:
os.unlink(temp_key_path)
except FileNotFoundError:
pass
@pytest.fixture
def invalid_key_file():
"""Create a temporary file with invalid key content."""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_invalid') as f:
f.write("This is not a valid private key")
temp_path = f.name
yield temp_path
# Cleanup
try:
os.unlink(temp_path)
except FileNotFoundError:
pass
class TestOllamaAuth:
"""Test suite for OllamaAuth class."""
def test_init_default_key_path(self):
"""Test initialization with default key path."""
auth = OllamaAuth()
expected_path = os.path.join(str(Path.home()), '.ollama', 'id_ed25519')
assert auth.key_path == expected_path
def test_init_custom_key_path(self):
"""Test initialization with custom key path."""
custom_path = "/custom/path/to/key"
auth = OllamaAuth(key_path=custom_path)
assert auth.key_path == custom_path
def test_init_expanduser_path(self):
"""Test initialization with path containing ~ expansion."""
auth = OllamaAuth(key_path="~/custom/key")
expected_path = os.path.expanduser("~/custom/key")
assert auth.key_path == expected_path
def test_init_expandvars_path(self):
"""Test initialization with path containing environment variables."""
with patch.dict(os.environ, {'TEST_DIR': '/test/dir'}):
auth = OllamaAuth(key_path="$TEST_DIR/key")
assert auth.key_path == "/test/dir/key"
def test_load_private_key_success(self, temp_key_pair):
"""Test successful private key loading."""
temp_key_path, expected_private_key = temp_key_pair
auth = OllamaAuth(key_path=temp_key_path)
loaded_key = auth.load_private_key()
# Verify the loaded key is an Ed25519 private key
assert isinstance(loaded_key, ed25519.Ed25519PrivateKey)
# Verify the public keys match (indirect way to verify it's the same key)
expected_public = expected_private_key.public_key()
loaded_public = loaded_key.public_key()
expected_bytes = expected_public.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw
)
loaded_bytes = loaded_public.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw
)
assert expected_bytes == loaded_bytes
def test_load_private_key_file_not_found(self):
"""Test FileNotFoundError when key file doesn't exist."""
auth = OllamaAuth(key_path="/nonexistent/path/key")
with pytest.raises(FileNotFoundError, match="Could not find Ollama private key"):
auth.load_private_key()
def test_load_private_key_invalid_key(self, invalid_key_file):
"""Test ValueError when key file contains invalid data."""
auth = OllamaAuth(key_path=invalid_key_file)
with pytest.raises(ValueError, match="Invalid private key"):
auth.load_private_key()
def test_get_public_key_b64(self, temp_key_pair):
"""Test base64 public key extraction."""
temp_key_path, private_key = temp_key_pair
auth = OllamaAuth(key_path=temp_key_path)
public_key_b64 = auth.get_public_key_b64(private_key)
# Verify it's a valid base64 string
try:
decoded = base64.b64decode(public_key_b64)
assert len(decoded) > 0
except Exception:
pytest.fail("Returned value is not valid base64")
# Verify the format by checking it matches the OpenSSH format
public_key = private_key.public_key()
openssh_pub = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH,
).decode('utf-8').strip()
expected_b64 = openssh_pub.split(' ')[1]
assert public_key_b64 == expected_b64
def test_get_public_key_b64_malformed_key(self):
"""Test ValueError when OpenSSH public key is malformed."""
auth = OllamaAuth()
# Create a mock private key that produces malformed OpenSSH output
mock_private_key = Mock()
mock_public_key = Mock()
mock_private_key.public_key.return_value = mock_public_key
mock_public_key.public_bytes.return_value = b"malformed"
with pytest.raises(ValueError, match="Malformed OpenSSH public key"):
auth.get_public_key_b64(mock_private_key)
def test_sign_request_basic(self, temp_key_pair):
"""Test basic request signing functionality."""
temp_key_path, private_key = temp_key_pair
auth = OllamaAuth(key_path=temp_key_path)
method = "POST"
path = "/api/chat"
with patch('time.time', return_value=1234567890):
auth_token, timestamp = auth.sign_request(method, path)
assert timestamp == "1234567890"
assert isinstance(auth_token, str)
assert ':' in auth_token # Should contain public_key:signature format
# Split and verify format
parts = auth_token.split(':')
assert len(parts) == 2
public_key_b64, signature_b64 = parts
# Verify public key is valid base64
try:
base64.b64decode(public_key_b64)
except Exception:
pytest.fail("Public key part is not valid base64")
# Verify signature is valid base64
try:
signature_bytes = base64.b64decode(signature_b64)
assert len(signature_bytes) > 0
except Exception:
pytest.fail("Signature part is not valid base64")
def test_sign_request_with_query_params(self, temp_key_pair):
"""Test request signing with existing query parameters."""
temp_key_path, private_key = temp_key_pair
auth = OllamaAuth(key_path=temp_key_path)
method = "GET"
path = "/api/models?format=json"
with patch('time.time', return_value=1234567890):
auth_token, timestamp = auth.sign_request(method, path)
assert timestamp == "1234567890"
assert isinstance(auth_token, str)
# The challenge should be "GET,/api/models?format=json&ts=1234567890"
# We can't easily verify the exact signature, but we can verify format
parts = auth_token.split(':')
assert len(parts) == 2
def test_sign_request_different_methods(self, temp_key_pair):
"""Test request signing with different HTTP methods."""
temp_key_path, private_key = temp_key_pair
auth = OllamaAuth(key_path=temp_key_path)
methods = ["GET", "POST", "PUT", "DELETE"]
path = "/api/test"
signatures = {}
for method in methods:
with patch('time.time', return_value=1234567890):
auth_token, timestamp = auth.sign_request(method, path)
signatures[method] = auth_token
# All signatures should be different (different challenges)
unique_signatures = set(signatures.values())
assert len(unique_signatures) == len(methods)
def test_sign_request_different_paths(self, temp_key_pair):
"""Test request signing with different paths."""
temp_key_path, private_key = temp_key_pair
auth = OllamaAuth(key_path=temp_key_path)
method = "POST"
paths = ["/api/chat", "/api/generate", "/api/models"]
signatures = {}
for path in paths:
with patch('time.time', return_value=1234567890):
auth_token, timestamp = auth.sign_request(method, path)
signatures[path] = auth_token
# All signatures should be different (different challenges)
unique_signatures = set(signatures.values())
assert len(unique_signatures) == len(paths)
def test_sign_request_file_not_found(self):
"""Test request signing when key file doesn't exist."""
auth = OllamaAuth(key_path="/nonexistent/path/key")
with pytest.raises(FileNotFoundError, match="Could not find Ollama private key"):
auth.sign_request("POST", "/api/chat")
def test_sign_request_invalid_key(self, invalid_key_file):
"""Test request signing with invalid key file."""
auth = OllamaAuth(key_path=invalid_key_file)
with pytest.raises(ValueError, match="Invalid private key"):
auth.sign_request("POST", "/api/chat")
@patch('time.time')
def test_sign_request_timestamp_generation(self, mock_time, temp_key_pair):
"""Test that timestamps are generated correctly."""
temp_key_path, private_key = temp_key_pair
auth = OllamaAuth(key_path=temp_key_path)
# Test with different timestamps
mock_time.return_value = 1000.5
_, timestamp1 = auth.sign_request("POST", "/api/chat")
assert timestamp1 == "1000"
mock_time.return_value = 2000.9
_, timestamp2 = auth.sign_request("POST", "/api/chat")
assert timestamp2 == "2000"
def test_signature_verification_challenge_format(self, temp_key_pair):
"""Test that the challenge is formatted correctly for signature verification."""
temp_key_path, private_key = temp_key_pair
auth = OllamaAuth(key_path=temp_key_path)
method = "POST"
path = "/api/chat"
with patch('time.time', return_value=1234567890):
auth_token, timestamp = auth.sign_request(method, path)
# Extract signature and verify it was created with correct challenge
public_key_b64, signature_b64 = auth_token.split(':')
signature = base64.b64decode(signature_b64)
# The challenge should be "POST,/api/chat?ts=1234567890"
expected_challenge = f"{method},{path}?ts={timestamp}"
# Load the private key and get public key for verification
loaded_private_key = auth.load_private_key()
public_key = loaded_private_key.public_key()
# Verify the signature
try:
public_key.verify(signature, expected_challenge.encode())
# If no exception is raised, signature is valid
except Exception:
pytest.fail("Signature verification failed - challenge format incorrect")
def test_integration_full_flow(self, temp_key_pair):
"""Test the complete authentication flow integration."""
temp_key_path, private_key = temp_key_pair
# Test complete flow: init -> sign -> verify
auth = OllamaAuth(key_path=temp_key_path)
method = "POST"
path = "/api/chat"
# Sign a request
auth_token, timestamp = auth.sign_request(method, path)
# Verify the components
assert isinstance(auth_token, str)
assert isinstance(timestamp, str)
assert int(timestamp) > 0 # Should be a valid timestamp
# Verify token format
parts = auth_token.split(':')
assert len(parts) == 2
public_key_b64, signature_b64 = parts
# Verify we can decode both parts
public_key_bytes = base64.b64decode(public_key_b64)
signature_bytes = base64.b64decode(signature_b64)
assert len(public_key_bytes) > 0
assert len(signature_bytes) > 0
# Verify the signature is valid for the challenge
challenge = f"{method},{path}?ts={timestamp}"
loaded_private_key = auth.load_private_key()
public_key = loaded_private_key.public_key()
# This should not raise an exception
public_key.verify(signature_bytes, challenge.encode())