From db7e2a980e408ab7c03c85f88c4e27a842d6e5d2 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Thu, 18 Sep 2025 13:28:08 -0700 Subject: [PATCH] auth: add auth through key signing --- ollama/_auth.py | 100 ++++++++++++++++++++++++++++++++++++++++++++++ ollama/_client.py | 47 ++++++++++++++++++---- 2 files changed, 140 insertions(+), 7 deletions(-) create mode 100644 ollama/_auth.py diff --git a/ollama/_auth.py b/ollama/_auth.py new file mode 100644 index 0000000..6d8f975 --- /dev/null +++ b/ollama/_auth.py @@ -0,0 +1,100 @@ +import base64 +import os +import time +from pathlib import Path +from typing import Optional + +from cryptography.hazmat.primitives import serialization + + +class OllamaAuth: + def __init__(self, key_path: Optional[str] = None): + """Initialize the OllamaAuth class. + + Args: + key_path: Optional path to the private key file. If not provided, + defaults to ~/.ollama/id_ed25519 + """ + if key_path is None: + home = str(Path.home()) + self.key_path = os.path.join(home, '.ollama', 'id_ed25519') + else: + # Expand ~ and environment variables in the path + self.key_path = os.path.expanduser(os.path.expandvars(key_path)) + + def load_private_key(self): + """Read and load the private key. + + Returns: + The loaded Ed25519 private key. + + Raises: + FileNotFoundError: If the key file doesn't exist + ValueError: If the key file is invalid + """ + try: + with open(self.key_path, 'rb') as f: + private_key_data = f.read() + + private_key = serialization.load_ssh_private_key( + private_key_data, + password=None, + ) + return private_key + except FileNotFoundError: + raise FileNotFoundError(f"Could not find Ollama private key at {self.key_path}. Please generate one using: ssh-keygen -t ed25519 -f ~/.ollama/id_ed25519 -N ''") + except Exception as e: + raise ValueError(f'Invalid private key at {self.key_path}: {e!s}') + + def get_public_key_b64(self, private_key): + """Get the base64 encoded public key. + + Args: + private_key: The Ed25519 private key + + Returns: + Base64 encoded public key string + """ + # Get the public key in OpenSSH format and extract the second field (base64-encoded key) + public_key = private_key.public_key() + openssh_pub = ( + public_key.public_bytes( + encoding=serialization.Encoding.OpenSSH, + format=serialization.PublicFormat.OpenSSH, + ) + .decode('utf-8') + .strip() + ) + parts = openssh_pub.split(' ') + if len(parts) < 2: + raise ValueError('Malformed OpenSSH public key') + public_key_b64 = parts[1] + return public_key_b64 + + def sign_request(self, method: str, path: str): + """Sign an HTTP request. + + Args: + method: The HTTP method (e.g. 'GET', 'POST') + path: The request path (e.g. '/api/chat') + + Returns: + A tuple of (auth_token, timestamp) where auth_token is the + authorization header value and timestamp is the request timestamp. + + Raises: + FileNotFoundError: If the key file doesn't exist + ValueError: If the key file is invalid + """ + timestamp = str(int(time.time())) + path_with_ts = f'{path}&ts={timestamp}' if '?' in path else f'{path}?ts={timestamp}' + challenge = f'{method},{path_with_ts}' + + private_key = self.load_private_key() + signature = private_key.sign(challenge.encode()) + + public_key_b64 = self.get_public_key_b64(private_key) + + auth_token = f'{public_key_b64}:{base64.b64encode(signature).decode("utf-8")}' + + return auth_token, timestamp diff --git a/ollama/_client.py b/ollama/_client.py index 4bcc1b1..c5e61e2 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -25,6 +25,7 @@ from typing import ( import anyio from pydantic.json_schema import JsonSchemaValue +from ollama._auth import OllamaAuth from ollama._utils import convert_function_to_tool if sys.version_info < (3, 9): @@ -80,6 +81,7 @@ class BaseClient: follow_redirects: bool = True, timeout: Any = None, headers: Optional[Mapping[str, str]] = None, + auth_key_path: Optional[str] = None, **kwargs, ) -> None: """ @@ -87,9 +89,10 @@ class BaseClient: except for the following: - `follow_redirects`: True - `timeout`: None + - `auth_key_path`: Optional path to the ed25519 private key for authentication `kwargs` are passed to the httpx client. """ - + self._auth = OllamaAuth(auth_key_path) self._client = client( base_url=_parse_host(host or os.getenv('OLLAMA_HOST')), follow_redirects=follow_redirects, @@ -107,6 +110,27 @@ class BaseClient: **kwargs, ) + def _prepare_request(self, method: str, path: str, **kwargs) -> Dict[str, Any]: + if self._auth: + url = str(self._client.build_request(method, path).url) + parsed = urllib.parse.urlparse(url) + full_path = parsed.path + if parsed.query: + full_path = f'{full_path}?{parsed.query}' + + auth_token, timestamp = self._auth.sign_request(method, full_path) + + if 'headers' not in kwargs: + kwargs['headers'] = {} + kwargs['headers']['Authorization'] = auth_token + + if '?' in path: + path = f'{path}&ts={timestamp}' + else: + path = f'{path}?ts={timestamp}' + + return {'method': method, 'url': path, **kwargs} + CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download' @@ -155,14 +179,18 @@ class Client(BaseClient): def _request( self, cls: Type[T], - *args, + method: str, + path: str, + *, stream: bool = False, **kwargs, ) -> Union[T, Iterator[T]]: + request_params = self._prepare_request(method, path, **kwargs) + if stream: def inner(): - with self._client.stream(*args, **kwargs) as r: + with self._client.stream(**request_params) as r: try: r.raise_for_status() except httpx.HTTPStatusError as e: @@ -177,7 +205,7 @@ class Client(BaseClient): return inner() - return cls(**self._request_raw(*args, **kwargs).json()) + return cls(**self._request_raw(**request_params).json()) @overload def generate( @@ -669,14 +697,19 @@ class AsyncClient(BaseClient): async def _request( self, cls: Type[T], - *args, + method: str, + path: str, + *, stream: bool = False, **kwargs, ) -> Union[T, AsyncIterator[T]]: + """Make a request with optional authentication.""" + request_params = self._prepare_request(method, path, **kwargs) + if stream: async def inner(): - async with self._client.stream(*args, **kwargs) as r: + async with self._client.stream(**request_params) as r: try: r.raise_for_status() except httpx.HTTPStatusError as e: @@ -691,7 +724,7 @@ class AsyncClient(BaseClient): return inner() - return cls(**(await self._request_raw(*args, **kwargs)).json()) + return cls(**(await self._request_raw(**request_params)).json()) @overload async def generate(