Add tests and improve auth header

This commit is contained in:
ParthSareen 2025-09-22 14:30:47 -07:00
parent 17fb82e215
commit 9aeb362580
3 changed files with 103 additions and 22 deletions

View File

@ -5,12 +5,11 @@
# "ollama", # "ollama",
# ] # ]
# /// # ///
import os
from typing import Union from typing import Union
from rich import print from rich import print
from ollama import Client, WebCrawlResponse, WebSearchResponse from ollama import WebCrawlResponse, WebSearchResponse, chat, web_crawl, web_search
def format_tool_results(results: Union[WebSearchResponse, WebCrawlResponse]): def format_tool_results(results: Union[WebSearchResponse, WebCrawlResponse]):
@ -49,15 +48,17 @@ def format_tool_results(results: Union[WebSearchResponse, WebCrawlResponse]):
return '\n'.join(output).rstrip() return '\n'.join(output).rstrip()
client = Client(headers={'Authorization': (os.getenv('OLLAMA_API_KEY'))}) # Set OLLAMA_API_KEY in the environment variable or use the headers parameter to set the authorization header
available_tools = {'web_search': client.web_search, 'web_crawl': client.web_crawl} # client = Client(headers={'Authorization': 'Bearer <OLLAMA_API_KEY>'})
available_tools = {'web_search': web_search, 'web_crawl': web_crawl}
query = "ollama's new engine" query = "ollama's new engine"
print('Query: ', query) print('Query: ', query)
messages = [{'role': 'user', 'content': query}] messages = [{'role': 'user', 'content': query}]
while True: while True:
response = client.chat(model='qwen3', messages=messages, tools=[client.web_search, client.web_crawl], think=True) response = chat(model='qwen3', messages=messages, tools=[web_search, web_crawl], think=True)
if response.message.thinking: if response.message.thinking:
print('Thinking: ') print('Thinking: ')
print(response.message.thinking + '\n\n') print(response.message.thinking + '\n\n')

View File

@ -94,23 +94,25 @@ class BaseClient:
`kwargs` are passed to the httpx client. `kwargs` are passed to the httpx client.
""" """
headers = {
k.lower(): v
for k, v in {
**(headers or {}),
'Content-Type': 'application/json',
'Accept': 'application/json',
'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
}.items()
if v is not None
}
api_key = os.getenv('OLLAMA_API_KEY', None) api_key = os.getenv('OLLAMA_API_KEY', None)
if not headers.get('authorization') and api_key:
headers['authorization'] = f'Bearer {api_key}'
self._client = client( self._client = client(
base_url=_parse_host(host or os.getenv('OLLAMA_HOST')), base_url=_parse_host(host or os.getenv('OLLAMA_HOST')),
follow_redirects=follow_redirects, follow_redirects=follow_redirects,
timeout=timeout, timeout=timeout,
# Lowercase all headers to ensure override headers=headers,
headers={
k.lower(): v
for k, v in {
**(headers or {}),
'Content-Type': 'application/json',
'Accept': 'application/json',
'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
'Authorization': f'Bearer {api_key}' if api_key else '',
}.items()
},
**kwargs, **kwargs,
) )
@ -644,9 +646,8 @@ class Client(BaseClient):
Raises: Raises:
ValueError: If OLLAMA_API_KEY environment variable is not set ValueError: If OLLAMA_API_KEY environment variable is not set
""" """
api_key = os.getenv('OLLAMA_API_KEY') if not self._client.headers.get('authorization', '').startswith('Bearer '):
if not api_key: raise ValueError('Authorization header with Bearer token is required for web search')
raise ValueError('OLLAMA_API_KEY environment variable is required for web search')
return self._request( return self._request(
WebSearchResponse, WebSearchResponse,
@ -670,9 +671,8 @@ class Client(BaseClient):
Raises: Raises:
ValueError: If OLLAMA_API_KEY environment variable is not set ValueError: If OLLAMA_API_KEY environment variable is not set
""" """
api_key = os.getenv('OLLAMA_API_KEY') if not self._client.headers.get('authorization', '').startswith('Bearer '):
if not api_key: raise ValueError('Authorization header with Bearer token is required for web fetch')
raise ValueError('OLLAMA_API_KEY environment variable is required for web fetch')
return self._request( return self._request(
WebCrawlResponse, WebCrawlResponse,

View File

@ -1195,3 +1195,83 @@ async def test_arbitrary_roles_accepted_in_message_request_async(monkeypatch: py
client = AsyncClient() client = AsyncClient()
await client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}]) await client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}])
def test_client_web_search_requires_bearer_auth_header(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
client = Client()
with pytest.raises(ValueError, match='Authorization header with Bearer token is required for web search'):
client.web_search(['test query'])
def test_client_web_crawl_requires_bearer_auth_header(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
client = Client()
with pytest.raises(ValueError, match='Authorization header with Bearer token is required for web fetch'):
client.web_crawl(['https://example.com'])
def _mock_request_web_search(self, cls, method, url, json=None, **kwargs):
assert method == 'POST'
assert url == 'https://ollama.com/api/web_search'
assert json is not None and 'queries' in json and 'max_results' in json
return httpxResponse(status_code=200, content='{"results": {}, "success": true}')
def _mock_request_web_crawl(self, cls, method, url, json=None, **kwargs):
assert method == 'POST'
assert url == 'https://ollama.com/api/web_crawl'
assert json is not None and 'urls' in json
return httpxResponse(status_code=200, content='{"results": {}, "success": true}')
def test_client_web_search_with_env_api_key(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv('OLLAMA_API_KEY', 'test-key')
monkeypatch.setattr(Client, '_request', _mock_request_web_search)
client = Client()
client.web_search(['what is ollama?'], max_results=2)
def test_client_web_crawl_with_env_api_key(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv('OLLAMA_API_KEY', 'test-key')
monkeypatch.setattr(Client, '_request', _mock_request_web_crawl)
client = Client()
client.web_crawl(['https://example.com'])
def test_client_web_search_with_explicit_bearer_header(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
monkeypatch.setattr(Client, '_request', _mock_request_web_search)
client = Client(headers={'Authorization': 'Bearer custom-token'})
client.web_search(['what is ollama?'], max_results=1)
def test_client_web_crawl_with_explicit_bearer_header(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
monkeypatch.setattr(Client, '_request', _mock_request_web_crawl)
client = Client(headers={'Authorization': 'Bearer custom-token'})
client.web_crawl(['https://example.com'])
def test_client_bearer_header_from_env(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv('OLLAMA_API_KEY', 'env-token')
client = Client()
assert client._client.headers['authorization'] == 'Bearer env-token'
def test_client_explicit_bearer_header_overrides_env(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv('OLLAMA_API_KEY', 'env-token')
monkeypatch.setattr(Client, '_request', _mock_request_web_search)
client = Client(headers={'Authorization': 'Bearer explicit-token'})
assert client._client.headers['authorization'] == 'Bearer explicit-token'
client.web_search(['override check'])