From 6e5bf2c8dbb8425c5b6d919a6e1cb468e2b7d61a Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Thu, 18 Sep 2025 13:57:57 -0700 Subject: [PATCH] fix lint --- ollama/_auth.py | 4 +- ollama/_client.py | 5 +- tests/test_auth.py | 353 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 356 insertions(+), 6 deletions(-) create mode 100644 tests/test_auth.py diff --git a/ollama/_auth.py b/ollama/_auth.py index 6d8f975..7efa727 100644 --- a/ollama/_auth.py +++ b/ollama/_auth.py @@ -42,9 +42,9 @@ class OllamaAuth: ) return private_key except FileNotFoundError: - raise FileNotFoundError(f"Could not find Ollama private key at {self.key_path}. Please generate one using: ssh-keygen -t ed25519 -f ~/.ollama/id_ed25519 -N ''") + raise FileNotFoundError(f"Could not find Ollama private key at {self.key_path}. Please generate one using: ssh-keygen -t ed25519 -f ~/.ollama/id_ed25519 -N ''") from None except Exception as e: - raise ValueError(f'Invalid private key at {self.key_path}: {e!s}') + raise ValueError(f'Invalid private key at {self.key_path}: {e!s}') from e def get_public_key_b64(self, private_key): """Get the base64 encoded public key. diff --git a/ollama/_client.py b/ollama/_client.py index c5e61e2..a8660dc 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -124,10 +124,7 @@ class BaseClient: kwargs['headers'] = {} kwargs['headers']['Authorization'] = auth_token - if '?' in path: - path = f'{path}&ts={timestamp}' - else: - path = f'{path}?ts={timestamp}' + path = f'{path}&ts={timestamp}' if '?' in path else f'{path}?ts={timestamp}' return {'method': method, 'url': path, **kwargs} diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..7c4e95c --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,353 @@ +import base64 +import os +import tempfile +import time +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 + +from ollama._auth import OllamaAuth + + +@pytest.fixture +def temp_key_pair(): + """Create a temporary Ed25519 key pair for testing.""" + # Generate a test key pair + private_key = ed25519.Ed25519PrivateKey.generate() + + # Serialize the private key in OpenSSH format + private_key_bytes = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.OpenSSH, + encryption_algorithm=serialization.NoEncryption() + ) + + # Create temporary file + with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='_ed25519') as f: + f.write(private_key_bytes) + temp_key_path = f.name + + yield temp_key_path, private_key + + # Cleanup + try: + os.unlink(temp_key_path) + except FileNotFoundError: + pass + + +@pytest.fixture +def invalid_key_file(): + """Create a temporary file with invalid key content.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_invalid') as f: + f.write("This is not a valid private key") + temp_path = f.name + + yield temp_path + + # Cleanup + try: + os.unlink(temp_path) + except FileNotFoundError: + pass + + +class TestOllamaAuth: + """Test suite for OllamaAuth class.""" + + def test_init_default_key_path(self): + """Test initialization with default key path.""" + auth = OllamaAuth() + expected_path = os.path.join(str(Path.home()), '.ollama', 'id_ed25519') + assert auth.key_path == expected_path + + def test_init_custom_key_path(self): + """Test initialization with custom key path.""" + custom_path = "/custom/path/to/key" + auth = OllamaAuth(key_path=custom_path) + assert auth.key_path == custom_path + + def test_init_expanduser_path(self): + """Test initialization with path containing ~ expansion.""" + auth = OllamaAuth(key_path="~/custom/key") + expected_path = os.path.expanduser("~/custom/key") + assert auth.key_path == expected_path + + def test_init_expandvars_path(self): + """Test initialization with path containing environment variables.""" + with patch.dict(os.environ, {'TEST_DIR': '/test/dir'}): + auth = OllamaAuth(key_path="$TEST_DIR/key") + assert auth.key_path == "/test/dir/key" + + def test_load_private_key_success(self, temp_key_pair): + """Test successful private key loading.""" + temp_key_path, expected_private_key = temp_key_pair + + auth = OllamaAuth(key_path=temp_key_path) + loaded_key = auth.load_private_key() + + # Verify the loaded key is an Ed25519 private key + assert isinstance(loaded_key, ed25519.Ed25519PrivateKey) + + # Verify the public keys match (indirect way to verify it's the same key) + expected_public = expected_private_key.public_key() + loaded_public = loaded_key.public_key() + + expected_bytes = expected_public.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw + ) + loaded_bytes = loaded_public.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw + ) + + assert expected_bytes == loaded_bytes + + def test_load_private_key_file_not_found(self): + """Test FileNotFoundError when key file doesn't exist.""" + auth = OllamaAuth(key_path="/nonexistent/path/key") + + with pytest.raises(FileNotFoundError, match="Could not find Ollama private key"): + auth.load_private_key() + + def test_load_private_key_invalid_key(self, invalid_key_file): + """Test ValueError when key file contains invalid data.""" + auth = OllamaAuth(key_path=invalid_key_file) + + with pytest.raises(ValueError, match="Invalid private key"): + auth.load_private_key() + + def test_get_public_key_b64(self, temp_key_pair): + """Test base64 public key extraction.""" + temp_key_path, private_key = temp_key_pair + + auth = OllamaAuth(key_path=temp_key_path) + public_key_b64 = auth.get_public_key_b64(private_key) + + # Verify it's a valid base64 string + try: + decoded = base64.b64decode(public_key_b64) + assert len(decoded) > 0 + except Exception: + pytest.fail("Returned value is not valid base64") + + # Verify the format by checking it matches the OpenSSH format + public_key = private_key.public_key() + openssh_pub = public_key.public_bytes( + encoding=serialization.Encoding.OpenSSH, + format=serialization.PublicFormat.OpenSSH, + ).decode('utf-8').strip() + + expected_b64 = openssh_pub.split(' ')[1] + assert public_key_b64 == expected_b64 + + def test_get_public_key_b64_malformed_key(self): + """Test ValueError when OpenSSH public key is malformed.""" + auth = OllamaAuth() + + # Create a mock private key that produces malformed OpenSSH output + mock_private_key = Mock() + mock_public_key = Mock() + mock_private_key.public_key.return_value = mock_public_key + mock_public_key.public_bytes.return_value = b"malformed" + + with pytest.raises(ValueError, match="Malformed OpenSSH public key"): + auth.get_public_key_b64(mock_private_key) + + def test_sign_request_basic(self, temp_key_pair): + """Test basic request signing functionality.""" + temp_key_path, private_key = temp_key_pair + + auth = OllamaAuth(key_path=temp_key_path) + + method = "POST" + path = "/api/chat" + + with patch('time.time', return_value=1234567890): + auth_token, timestamp = auth.sign_request(method, path) + + assert timestamp == "1234567890" + assert isinstance(auth_token, str) + assert ':' in auth_token # Should contain public_key:signature format + + # Split and verify format + parts = auth_token.split(':') + assert len(parts) == 2 + + public_key_b64, signature_b64 = parts + + # Verify public key is valid base64 + try: + base64.b64decode(public_key_b64) + except Exception: + pytest.fail("Public key part is not valid base64") + + # Verify signature is valid base64 + try: + signature_bytes = base64.b64decode(signature_b64) + assert len(signature_bytes) > 0 + except Exception: + pytest.fail("Signature part is not valid base64") + + def test_sign_request_with_query_params(self, temp_key_pair): + """Test request signing with existing query parameters.""" + temp_key_path, private_key = temp_key_pair + + auth = OllamaAuth(key_path=temp_key_path) + + method = "GET" + path = "/api/models?format=json" + + with patch('time.time', return_value=1234567890): + auth_token, timestamp = auth.sign_request(method, path) + + assert timestamp == "1234567890" + assert isinstance(auth_token, str) + + # The challenge should be "GET,/api/models?format=json&ts=1234567890" + # We can't easily verify the exact signature, but we can verify format + parts = auth_token.split(':') + assert len(parts) == 2 + + def test_sign_request_different_methods(self, temp_key_pair): + """Test request signing with different HTTP methods.""" + temp_key_path, private_key = temp_key_pair + + auth = OllamaAuth(key_path=temp_key_path) + + methods = ["GET", "POST", "PUT", "DELETE"] + path = "/api/test" + + signatures = {} + + for method in methods: + with patch('time.time', return_value=1234567890): + auth_token, timestamp = auth.sign_request(method, path) + signatures[method] = auth_token + + # All signatures should be different (different challenges) + unique_signatures = set(signatures.values()) + assert len(unique_signatures) == len(methods) + + def test_sign_request_different_paths(self, temp_key_pair): + """Test request signing with different paths.""" + temp_key_path, private_key = temp_key_pair + + auth = OllamaAuth(key_path=temp_key_path) + + method = "POST" + paths = ["/api/chat", "/api/generate", "/api/models"] + + signatures = {} + + for path in paths: + with patch('time.time', return_value=1234567890): + auth_token, timestamp = auth.sign_request(method, path) + signatures[path] = auth_token + + # All signatures should be different (different challenges) + unique_signatures = set(signatures.values()) + assert len(unique_signatures) == len(paths) + + def test_sign_request_file_not_found(self): + """Test request signing when key file doesn't exist.""" + auth = OllamaAuth(key_path="/nonexistent/path/key") + + with pytest.raises(FileNotFoundError, match="Could not find Ollama private key"): + auth.sign_request("POST", "/api/chat") + + def test_sign_request_invalid_key(self, invalid_key_file): + """Test request signing with invalid key file.""" + auth = OllamaAuth(key_path=invalid_key_file) + + with pytest.raises(ValueError, match="Invalid private key"): + auth.sign_request("POST", "/api/chat") + + @patch('time.time') + def test_sign_request_timestamp_generation(self, mock_time, temp_key_pair): + """Test that timestamps are generated correctly.""" + temp_key_path, private_key = temp_key_pair + + auth = OllamaAuth(key_path=temp_key_path) + + # Test with different timestamps + mock_time.return_value = 1000.5 + _, timestamp1 = auth.sign_request("POST", "/api/chat") + assert timestamp1 == "1000" + + mock_time.return_value = 2000.9 + _, timestamp2 = auth.sign_request("POST", "/api/chat") + assert timestamp2 == "2000" + + def test_signature_verification_challenge_format(self, temp_key_pair): + """Test that the challenge is formatted correctly for signature verification.""" + temp_key_path, private_key = temp_key_pair + + auth = OllamaAuth(key_path=temp_key_path) + + method = "POST" + path = "/api/chat" + + with patch('time.time', return_value=1234567890): + auth_token, timestamp = auth.sign_request(method, path) + + # Extract signature and verify it was created with correct challenge + public_key_b64, signature_b64 = auth_token.split(':') + signature = base64.b64decode(signature_b64) + + # The challenge should be "POST,/api/chat?ts=1234567890" + expected_challenge = f"{method},{path}?ts={timestamp}" + + # Load the private key and get public key for verification + loaded_private_key = auth.load_private_key() + public_key = loaded_private_key.public_key() + + # Verify the signature + try: + public_key.verify(signature, expected_challenge.encode()) + # If no exception is raised, signature is valid + except Exception: + pytest.fail("Signature verification failed - challenge format incorrect") + + def test_integration_full_flow(self, temp_key_pair): + """Test the complete authentication flow integration.""" + temp_key_path, private_key = temp_key_pair + + # Test complete flow: init -> sign -> verify + auth = OllamaAuth(key_path=temp_key_path) + + method = "POST" + path = "/api/chat" + + # Sign a request + auth_token, timestamp = auth.sign_request(method, path) + + # Verify the components + assert isinstance(auth_token, str) + assert isinstance(timestamp, str) + assert int(timestamp) > 0 # Should be a valid timestamp + + # Verify token format + parts = auth_token.split(':') + assert len(parts) == 2 + + public_key_b64, signature_b64 = parts + + # Verify we can decode both parts + public_key_bytes = base64.b64decode(public_key_b64) + signature_bytes = base64.b64decode(signature_b64) + + assert len(public_key_bytes) > 0 + assert len(signature_bytes) > 0 + + # Verify the signature is valid for the challenge + challenge = f"{method},{path}?ts={timestamp}" + loaded_private_key = auth.load_private_key() + public_key = loaded_private_key.public_key() + + # This should not raise an exception + public_key.verify(signature_bytes, challenge.encode())