mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-13 21:57:16 +08:00
fix lint
This commit is contained in:
parent
9a90929137
commit
6e5bf2c8db
@ -42,9 +42,9 @@ class OllamaAuth:
|
||||
)
|
||||
return private_key
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"Could not find Ollama private key at {self.key_path}. Please generate one using: ssh-keygen -t ed25519 -f ~/.ollama/id_ed25519 -N ''")
|
||||
raise FileNotFoundError(f"Could not find Ollama private key at {self.key_path}. Please generate one using: ssh-keygen -t ed25519 -f ~/.ollama/id_ed25519 -N ''") from None
|
||||
except Exception as e:
|
||||
raise ValueError(f'Invalid private key at {self.key_path}: {e!s}')
|
||||
raise ValueError(f'Invalid private key at {self.key_path}: {e!s}') from e
|
||||
|
||||
def get_public_key_b64(self, private_key):
|
||||
"""Get the base64 encoded public key.
|
||||
|
||||
@ -124,10 +124,7 @@ class BaseClient:
|
||||
kwargs['headers'] = {}
|
||||
kwargs['headers']['Authorization'] = auth_token
|
||||
|
||||
if '?' in path:
|
||||
path = f'{path}&ts={timestamp}'
|
||||
else:
|
||||
path = f'{path}?ts={timestamp}'
|
||||
path = f'{path}&ts={timestamp}' if '?' in path else f'{path}?ts={timestamp}'
|
||||
|
||||
return {'method': method, 'url': path, **kwargs}
|
||||
|
||||
|
||||
353
tests/test_auth.py
Normal file
353
tests/test_auth.py
Normal file
@ -0,0 +1,353 @@
|
||||
import base64
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
from ollama._auth import OllamaAuth
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_key_pair():
|
||||
"""Create a temporary Ed25519 key pair for testing."""
|
||||
# Generate a test key pair
|
||||
private_key = ed25519.Ed25519PrivateKey.generate()
|
||||
|
||||
# Serialize the private key in OpenSSH format
|
||||
private_key_bytes = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.OpenSSH,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
)
|
||||
|
||||
# Create temporary file
|
||||
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='_ed25519') as f:
|
||||
f.write(private_key_bytes)
|
||||
temp_key_path = f.name
|
||||
|
||||
yield temp_key_path, private_key
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
os.unlink(temp_key_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_key_file():
|
||||
"""Create a temporary file with invalid key content."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_invalid') as f:
|
||||
f.write("This is not a valid private key")
|
||||
temp_path = f.name
|
||||
|
||||
yield temp_path
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
class TestOllamaAuth:
|
||||
"""Test suite for OllamaAuth class."""
|
||||
|
||||
def test_init_default_key_path(self):
|
||||
"""Test initialization with default key path."""
|
||||
auth = OllamaAuth()
|
||||
expected_path = os.path.join(str(Path.home()), '.ollama', 'id_ed25519')
|
||||
assert auth.key_path == expected_path
|
||||
|
||||
def test_init_custom_key_path(self):
|
||||
"""Test initialization with custom key path."""
|
||||
custom_path = "/custom/path/to/key"
|
||||
auth = OllamaAuth(key_path=custom_path)
|
||||
assert auth.key_path == custom_path
|
||||
|
||||
def test_init_expanduser_path(self):
|
||||
"""Test initialization with path containing ~ expansion."""
|
||||
auth = OllamaAuth(key_path="~/custom/key")
|
||||
expected_path = os.path.expanduser("~/custom/key")
|
||||
assert auth.key_path == expected_path
|
||||
|
||||
def test_init_expandvars_path(self):
|
||||
"""Test initialization with path containing environment variables."""
|
||||
with patch.dict(os.environ, {'TEST_DIR': '/test/dir'}):
|
||||
auth = OllamaAuth(key_path="$TEST_DIR/key")
|
||||
assert auth.key_path == "/test/dir/key"
|
||||
|
||||
def test_load_private_key_success(self, temp_key_pair):
|
||||
"""Test successful private key loading."""
|
||||
temp_key_path, expected_private_key = temp_key_pair
|
||||
|
||||
auth = OllamaAuth(key_path=temp_key_path)
|
||||
loaded_key = auth.load_private_key()
|
||||
|
||||
# Verify the loaded key is an Ed25519 private key
|
||||
assert isinstance(loaded_key, ed25519.Ed25519PrivateKey)
|
||||
|
||||
# Verify the public keys match (indirect way to verify it's the same key)
|
||||
expected_public = expected_private_key.public_key()
|
||||
loaded_public = loaded_key.public_key()
|
||||
|
||||
expected_bytes = expected_public.public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw
|
||||
)
|
||||
loaded_bytes = loaded_public.public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw
|
||||
)
|
||||
|
||||
assert expected_bytes == loaded_bytes
|
||||
|
||||
def test_load_private_key_file_not_found(self):
|
||||
"""Test FileNotFoundError when key file doesn't exist."""
|
||||
auth = OllamaAuth(key_path="/nonexistent/path/key")
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="Could not find Ollama private key"):
|
||||
auth.load_private_key()
|
||||
|
||||
def test_load_private_key_invalid_key(self, invalid_key_file):
|
||||
"""Test ValueError when key file contains invalid data."""
|
||||
auth = OllamaAuth(key_path=invalid_key_file)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid private key"):
|
||||
auth.load_private_key()
|
||||
|
||||
def test_get_public_key_b64(self, temp_key_pair):
|
||||
"""Test base64 public key extraction."""
|
||||
temp_key_path, private_key = temp_key_pair
|
||||
|
||||
auth = OllamaAuth(key_path=temp_key_path)
|
||||
public_key_b64 = auth.get_public_key_b64(private_key)
|
||||
|
||||
# Verify it's a valid base64 string
|
||||
try:
|
||||
decoded = base64.b64decode(public_key_b64)
|
||||
assert len(decoded) > 0
|
||||
except Exception:
|
||||
pytest.fail("Returned value is not valid base64")
|
||||
|
||||
# Verify the format by checking it matches the OpenSSH format
|
||||
public_key = private_key.public_key()
|
||||
openssh_pub = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.OpenSSH,
|
||||
format=serialization.PublicFormat.OpenSSH,
|
||||
).decode('utf-8').strip()
|
||||
|
||||
expected_b64 = openssh_pub.split(' ')[1]
|
||||
assert public_key_b64 == expected_b64
|
||||
|
||||
def test_get_public_key_b64_malformed_key(self):
|
||||
"""Test ValueError when OpenSSH public key is malformed."""
|
||||
auth = OllamaAuth()
|
||||
|
||||
# Create a mock private key that produces malformed OpenSSH output
|
||||
mock_private_key = Mock()
|
||||
mock_public_key = Mock()
|
||||
mock_private_key.public_key.return_value = mock_public_key
|
||||
mock_public_key.public_bytes.return_value = b"malformed"
|
||||
|
||||
with pytest.raises(ValueError, match="Malformed OpenSSH public key"):
|
||||
auth.get_public_key_b64(mock_private_key)
|
||||
|
||||
def test_sign_request_basic(self, temp_key_pair):
|
||||
"""Test basic request signing functionality."""
|
||||
temp_key_path, private_key = temp_key_pair
|
||||
|
||||
auth = OllamaAuth(key_path=temp_key_path)
|
||||
|
||||
method = "POST"
|
||||
path = "/api/chat"
|
||||
|
||||
with patch('time.time', return_value=1234567890):
|
||||
auth_token, timestamp = auth.sign_request(method, path)
|
||||
|
||||
assert timestamp == "1234567890"
|
||||
assert isinstance(auth_token, str)
|
||||
assert ':' in auth_token # Should contain public_key:signature format
|
||||
|
||||
# Split and verify format
|
||||
parts = auth_token.split(':')
|
||||
assert len(parts) == 2
|
||||
|
||||
public_key_b64, signature_b64 = parts
|
||||
|
||||
# Verify public key is valid base64
|
||||
try:
|
||||
base64.b64decode(public_key_b64)
|
||||
except Exception:
|
||||
pytest.fail("Public key part is not valid base64")
|
||||
|
||||
# Verify signature is valid base64
|
||||
try:
|
||||
signature_bytes = base64.b64decode(signature_b64)
|
||||
assert len(signature_bytes) > 0
|
||||
except Exception:
|
||||
pytest.fail("Signature part is not valid base64")
|
||||
|
||||
def test_sign_request_with_query_params(self, temp_key_pair):
|
||||
"""Test request signing with existing query parameters."""
|
||||
temp_key_path, private_key = temp_key_pair
|
||||
|
||||
auth = OllamaAuth(key_path=temp_key_path)
|
||||
|
||||
method = "GET"
|
||||
path = "/api/models?format=json"
|
||||
|
||||
with patch('time.time', return_value=1234567890):
|
||||
auth_token, timestamp = auth.sign_request(method, path)
|
||||
|
||||
assert timestamp == "1234567890"
|
||||
assert isinstance(auth_token, str)
|
||||
|
||||
# The challenge should be "GET,/api/models?format=json&ts=1234567890"
|
||||
# We can't easily verify the exact signature, but we can verify format
|
||||
parts = auth_token.split(':')
|
||||
assert len(parts) == 2
|
||||
|
||||
def test_sign_request_different_methods(self, temp_key_pair):
|
||||
"""Test request signing with different HTTP methods."""
|
||||
temp_key_path, private_key = temp_key_pair
|
||||
|
||||
auth = OllamaAuth(key_path=temp_key_path)
|
||||
|
||||
methods = ["GET", "POST", "PUT", "DELETE"]
|
||||
path = "/api/test"
|
||||
|
||||
signatures = {}
|
||||
|
||||
for method in methods:
|
||||
with patch('time.time', return_value=1234567890):
|
||||
auth_token, timestamp = auth.sign_request(method, path)
|
||||
signatures[method] = auth_token
|
||||
|
||||
# All signatures should be different (different challenges)
|
||||
unique_signatures = set(signatures.values())
|
||||
assert len(unique_signatures) == len(methods)
|
||||
|
||||
def test_sign_request_different_paths(self, temp_key_pair):
|
||||
"""Test request signing with different paths."""
|
||||
temp_key_path, private_key = temp_key_pair
|
||||
|
||||
auth = OllamaAuth(key_path=temp_key_path)
|
||||
|
||||
method = "POST"
|
||||
paths = ["/api/chat", "/api/generate", "/api/models"]
|
||||
|
||||
signatures = {}
|
||||
|
||||
for path in paths:
|
||||
with patch('time.time', return_value=1234567890):
|
||||
auth_token, timestamp = auth.sign_request(method, path)
|
||||
signatures[path] = auth_token
|
||||
|
||||
# All signatures should be different (different challenges)
|
||||
unique_signatures = set(signatures.values())
|
||||
assert len(unique_signatures) == len(paths)
|
||||
|
||||
def test_sign_request_file_not_found(self):
|
||||
"""Test request signing when key file doesn't exist."""
|
||||
auth = OllamaAuth(key_path="/nonexistent/path/key")
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="Could not find Ollama private key"):
|
||||
auth.sign_request("POST", "/api/chat")
|
||||
|
||||
def test_sign_request_invalid_key(self, invalid_key_file):
|
||||
"""Test request signing with invalid key file."""
|
||||
auth = OllamaAuth(key_path=invalid_key_file)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid private key"):
|
||||
auth.sign_request("POST", "/api/chat")
|
||||
|
||||
@patch('time.time')
|
||||
def test_sign_request_timestamp_generation(self, mock_time, temp_key_pair):
|
||||
"""Test that timestamps are generated correctly."""
|
||||
temp_key_path, private_key = temp_key_pair
|
||||
|
||||
auth = OllamaAuth(key_path=temp_key_path)
|
||||
|
||||
# Test with different timestamps
|
||||
mock_time.return_value = 1000.5
|
||||
_, timestamp1 = auth.sign_request("POST", "/api/chat")
|
||||
assert timestamp1 == "1000"
|
||||
|
||||
mock_time.return_value = 2000.9
|
||||
_, timestamp2 = auth.sign_request("POST", "/api/chat")
|
||||
assert timestamp2 == "2000"
|
||||
|
||||
def test_signature_verification_challenge_format(self, temp_key_pair):
|
||||
"""Test that the challenge is formatted correctly for signature verification."""
|
||||
temp_key_path, private_key = temp_key_pair
|
||||
|
||||
auth = OllamaAuth(key_path=temp_key_path)
|
||||
|
||||
method = "POST"
|
||||
path = "/api/chat"
|
||||
|
||||
with patch('time.time', return_value=1234567890):
|
||||
auth_token, timestamp = auth.sign_request(method, path)
|
||||
|
||||
# Extract signature and verify it was created with correct challenge
|
||||
public_key_b64, signature_b64 = auth_token.split(':')
|
||||
signature = base64.b64decode(signature_b64)
|
||||
|
||||
# The challenge should be "POST,/api/chat?ts=1234567890"
|
||||
expected_challenge = f"{method},{path}?ts={timestamp}"
|
||||
|
||||
# Load the private key and get public key for verification
|
||||
loaded_private_key = auth.load_private_key()
|
||||
public_key = loaded_private_key.public_key()
|
||||
|
||||
# Verify the signature
|
||||
try:
|
||||
public_key.verify(signature, expected_challenge.encode())
|
||||
# If no exception is raised, signature is valid
|
||||
except Exception:
|
||||
pytest.fail("Signature verification failed - challenge format incorrect")
|
||||
|
||||
def test_integration_full_flow(self, temp_key_pair):
|
||||
"""Test the complete authentication flow integration."""
|
||||
temp_key_path, private_key = temp_key_pair
|
||||
|
||||
# Test complete flow: init -> sign -> verify
|
||||
auth = OllamaAuth(key_path=temp_key_path)
|
||||
|
||||
method = "POST"
|
||||
path = "/api/chat"
|
||||
|
||||
# Sign a request
|
||||
auth_token, timestamp = auth.sign_request(method, path)
|
||||
|
||||
# Verify the components
|
||||
assert isinstance(auth_token, str)
|
||||
assert isinstance(timestamp, str)
|
||||
assert int(timestamp) > 0 # Should be a valid timestamp
|
||||
|
||||
# Verify token format
|
||||
parts = auth_token.split(':')
|
||||
assert len(parts) == 2
|
||||
|
||||
public_key_b64, signature_b64 = parts
|
||||
|
||||
# Verify we can decode both parts
|
||||
public_key_bytes = base64.b64decode(public_key_b64)
|
||||
signature_bytes = base64.b64decode(signature_b64)
|
||||
|
||||
assert len(public_key_bytes) > 0
|
||||
assert len(signature_bytes) > 0
|
||||
|
||||
# Verify the signature is valid for the challenge
|
||||
challenge = f"{method},{path}?ts={timestamp}"
|
||||
loaded_private_key = auth.load_private_key()
|
||||
public_key = loaded_private_key.public_key()
|
||||
|
||||
# This should not raise an exception
|
||||
public_key.verify(signature_bytes, challenge.encode())
|
||||
Loading…
Reference in New Issue
Block a user