mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-14 06:07:17 +08:00
Add tests and improve auth header
This commit is contained in:
parent
17fb82e215
commit
9aeb362580
@ -5,12 +5,11 @@
|
|||||||
# "ollama",
|
# "ollama",
|
||||||
# ]
|
# ]
|
||||||
# ///
|
# ///
|
||||||
import os
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from rich import print
|
from rich import print
|
||||||
|
|
||||||
from ollama import Client, WebCrawlResponse, WebSearchResponse
|
from ollama import WebCrawlResponse, WebSearchResponse, chat, web_crawl, web_search
|
||||||
|
|
||||||
|
|
||||||
def format_tool_results(results: Union[WebSearchResponse, WebCrawlResponse]):
|
def format_tool_results(results: Union[WebSearchResponse, WebCrawlResponse]):
|
||||||
@ -49,15 +48,17 @@ def format_tool_results(results: Union[WebSearchResponse, WebCrawlResponse]):
|
|||||||
return '\n'.join(output).rstrip()
|
return '\n'.join(output).rstrip()
|
||||||
|
|
||||||
|
|
||||||
client = Client(headers={'Authorization': (os.getenv('OLLAMA_API_KEY'))})
|
# Set OLLAMA_API_KEY in the environment variable or use the headers parameter to set the authorization header
|
||||||
available_tools = {'web_search': client.web_search, 'web_crawl': client.web_crawl}
|
# client = Client(headers={'Authorization': 'Bearer <OLLAMA_API_KEY>'})
|
||||||
|
|
||||||
|
available_tools = {'web_search': web_search, 'web_crawl': web_crawl}
|
||||||
|
|
||||||
query = "ollama's new engine"
|
query = "ollama's new engine"
|
||||||
print('Query: ', query)
|
print('Query: ', query)
|
||||||
|
|
||||||
messages = [{'role': 'user', 'content': query}]
|
messages = [{'role': 'user', 'content': query}]
|
||||||
while True:
|
while True:
|
||||||
response = client.chat(model='qwen3', messages=messages, tools=[client.web_search, client.web_crawl], think=True)
|
response = chat(model='qwen3', messages=messages, tools=[web_search, web_crawl], think=True)
|
||||||
if response.message.thinking:
|
if response.message.thinking:
|
||||||
print('Thinking: ')
|
print('Thinking: ')
|
||||||
print(response.message.thinking + '\n\n')
|
print(response.message.thinking + '\n\n')
|
||||||
|
|||||||
@ -94,23 +94,25 @@ class BaseClient:
|
|||||||
`kwargs` are passed to the httpx client.
|
`kwargs` are passed to the httpx client.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
k.lower(): v
|
||||||
|
for k, v in {
|
||||||
|
**(headers or {}),
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Accept': 'application/json',
|
||||||
|
'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
|
||||||
|
}.items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
api_key = os.getenv('OLLAMA_API_KEY', None)
|
api_key = os.getenv('OLLAMA_API_KEY', None)
|
||||||
|
if not headers.get('authorization') and api_key:
|
||||||
|
headers['authorization'] = f'Bearer {api_key}'
|
||||||
|
|
||||||
self._client = client(
|
self._client = client(
|
||||||
base_url=_parse_host(host or os.getenv('OLLAMA_HOST')),
|
base_url=_parse_host(host or os.getenv('OLLAMA_HOST')),
|
||||||
follow_redirects=follow_redirects,
|
follow_redirects=follow_redirects,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
# Lowercase all headers to ensure override
|
headers=headers,
|
||||||
headers={
|
|
||||||
k.lower(): v
|
|
||||||
for k, v in {
|
|
||||||
**(headers or {}),
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Accept': 'application/json',
|
|
||||||
'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
|
|
||||||
'Authorization': f'Bearer {api_key}' if api_key else '',
|
|
||||||
}.items()
|
|
||||||
},
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -644,9 +646,8 @@ class Client(BaseClient):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If OLLAMA_API_KEY environment variable is not set
|
ValueError: If OLLAMA_API_KEY environment variable is not set
|
||||||
"""
|
"""
|
||||||
api_key = os.getenv('OLLAMA_API_KEY')
|
if not self._client.headers.get('authorization', '').startswith('Bearer '):
|
||||||
if not api_key:
|
raise ValueError('Authorization header with Bearer token is required for web search')
|
||||||
raise ValueError('OLLAMA_API_KEY environment variable is required for web search')
|
|
||||||
|
|
||||||
return self._request(
|
return self._request(
|
||||||
WebSearchResponse,
|
WebSearchResponse,
|
||||||
@ -670,9 +671,8 @@ class Client(BaseClient):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If OLLAMA_API_KEY environment variable is not set
|
ValueError: If OLLAMA_API_KEY environment variable is not set
|
||||||
"""
|
"""
|
||||||
api_key = os.getenv('OLLAMA_API_KEY')
|
if not self._client.headers.get('authorization', '').startswith('Bearer '):
|
||||||
if not api_key:
|
raise ValueError('Authorization header with Bearer token is required for web fetch')
|
||||||
raise ValueError('OLLAMA_API_KEY environment variable is required for web fetch')
|
|
||||||
|
|
||||||
return self._request(
|
return self._request(
|
||||||
WebCrawlResponse,
|
WebCrawlResponse,
|
||||||
|
|||||||
@ -1195,3 +1195,83 @@ async def test_arbitrary_roles_accepted_in_message_request_async(monkeypatch: py
|
|||||||
client = AsyncClient()
|
client = AsyncClient()
|
||||||
|
|
||||||
await client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}])
|
await client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}])
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_web_search_requires_bearer_auth_header(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
|
||||||
|
|
||||||
|
client = Client()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match='Authorization header with Bearer token is required for web search'):
|
||||||
|
client.web_search(['test query'])
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_web_crawl_requires_bearer_auth_header(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
|
||||||
|
|
||||||
|
client = Client()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match='Authorization header with Bearer token is required for web fetch'):
|
||||||
|
client.web_crawl(['https://example.com'])
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_request_web_search(self, cls, method, url, json=None, **kwargs):
|
||||||
|
assert method == 'POST'
|
||||||
|
assert url == 'https://ollama.com/api/web_search'
|
||||||
|
assert json is not None and 'queries' in json and 'max_results' in json
|
||||||
|
return httpxResponse(status_code=200, content='{"results": {}, "success": true}')
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_request_web_crawl(self, cls, method, url, json=None, **kwargs):
|
||||||
|
assert method == 'POST'
|
||||||
|
assert url == 'https://ollama.com/api/web_crawl'
|
||||||
|
assert json is not None and 'urls' in json
|
||||||
|
return httpxResponse(status_code=200, content='{"results": {}, "success": true}')
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_web_search_with_env_api_key(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
monkeypatch.setenv('OLLAMA_API_KEY', 'test-key')
|
||||||
|
monkeypatch.setattr(Client, '_request', _mock_request_web_search)
|
||||||
|
|
||||||
|
client = Client()
|
||||||
|
client.web_search(['what is ollama?'], max_results=2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_web_crawl_with_env_api_key(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
monkeypatch.setenv('OLLAMA_API_KEY', 'test-key')
|
||||||
|
monkeypatch.setattr(Client, '_request', _mock_request_web_crawl)
|
||||||
|
|
||||||
|
client = Client()
|
||||||
|
client.web_crawl(['https://example.com'])
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_web_search_with_explicit_bearer_header(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
|
||||||
|
monkeypatch.setattr(Client, '_request', _mock_request_web_search)
|
||||||
|
|
||||||
|
client = Client(headers={'Authorization': 'Bearer custom-token'})
|
||||||
|
client.web_search(['what is ollama?'], max_results=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_web_crawl_with_explicit_bearer_header(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
|
||||||
|
monkeypatch.setattr(Client, '_request', _mock_request_web_crawl)
|
||||||
|
|
||||||
|
client = Client(headers={'Authorization': 'Bearer custom-token'})
|
||||||
|
client.web_crawl(['https://example.com'])
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_bearer_header_from_env(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
monkeypatch.setenv('OLLAMA_API_KEY', 'env-token')
|
||||||
|
|
||||||
|
client = Client()
|
||||||
|
assert client._client.headers['authorization'] == 'Bearer env-token'
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_explicit_bearer_header_overrides_env(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
monkeypatch.setenv('OLLAMA_API_KEY', 'env-token')
|
||||||
|
monkeypatch.setattr(Client, '_request', _mock_request_web_search)
|
||||||
|
|
||||||
|
client = Client(headers={'Authorization': 'Bearer explicit-token'})
|
||||||
|
assert client._client.headers['authorization'] == 'Bearer explicit-token'
|
||||||
|
client.web_search(['override check'])
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user