mirror of
https://github.com/ollama/ollama-python.git
synced 2026-06-11 10:44:46 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c919192116 | |||
| ab49a669cd | |||
| 16f344f635 | |||
| d0f71bc8b8 | |||
| b22c5fdabb | |||
| 4d0b81b37a |
+57
-13
@@ -1,80 +1,124 @@
|
||||
# Running Examples
|
||||
|
||||
Run the examples in this directory with:
|
||||
|
||||
```sh
|
||||
# Run example
|
||||
python3 examples/<example>.py
|
||||
|
||||
# or with uv
|
||||
uv run examples/<example>.py
|
||||
```
|
||||
|
||||
See [ollama/docs/api.md](https://github.com/ollama/ollama/blob/main/docs/api.md) for the full API documentation.
|
||||
|
||||
### Chat - Chat with a model
|
||||
|
||||
- [chat.py](chat.py)
|
||||
- [async-chat.py](async-chat.py)
|
||||
- [chat-stream.py](chat-stream.py) - Streamed outputs
|
||||
- [chat-with-history.py](chat-with-history.py) - Chat with model and maintain history of the conversation
|
||||
|
||||
|
||||
### Generate - Generate text with a model
|
||||
|
||||
- [generate.py](generate.py)
|
||||
- [async-generate.py](async-generate.py)
|
||||
- [generate-stream.py](generate-stream.py) - Streamed outputs
|
||||
- [fill-in-middle.py](fill-in-middle.py) - Given a prefix and suffix, fill in the middle
|
||||
|
||||
|
||||
### Tools/Function Calling - Call a function with a model
|
||||
|
||||
- [tools.py](tools.py) - Simple example of Tools/Function Calling
|
||||
- [async-tools.py](async-tools.py)
|
||||
- [multi-tool.py](multi-tool.py) - Using multiple tools, with thinking enabled
|
||||
|
||||
#### gpt-oss
|
||||
#### gpt-oss
|
||||
|
||||
- [gpt-oss-tools.py](gpt-oss-tools.py)
|
||||
- [gpt-oss-tools-stream.py](gpt-oss-tools-stream.py)
|
||||
- [gpt-oss-tools-stream.py](gpt-oss-tools-stream.py)
|
||||
- [gpt-oss-tools-browser.py](gpt-oss-tools-browser.py) - Using browser research tools with gpt-oss
|
||||
- [gpt-oss-tools-browser-stream.py](gpt-oss-tools-browser-stream.py) - Using browser research tools with gpt-oss, with streaming enabled
|
||||
|
||||
### Web search
|
||||
|
||||
An API key from Ollama's cloud service is required. You can create one [here](https://ollama.com/settings/keys).
|
||||
|
||||
```shell
|
||||
export OLLAMA_API_KEY="your_api_key_here"
|
||||
```
|
||||
|
||||
- [web-search.py](web-search.py)
|
||||
|
||||
#### MCP server
|
||||
|
||||
The MCP server can be used with an MCP client like Cursor, Cline, Codex, Open WebUI, Goose, and more.
|
||||
|
||||
```sh
|
||||
uv run examples/web-search-mcp.py
|
||||
```
|
||||
|
||||
Configuration to use with an MCP client:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"web_search": {
|
||||
"type": "stdio",
|
||||
"command": "uv",
|
||||
"args": ["run", "path/to/ollama-python/examples/web-search-mcp.py"],
|
||||
"env": { "OLLAMA_API_KEY": "your_api_key_here" }
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- [web-search-mcp.py](web-search-mcp.py)
|
||||
|
||||
### Multimodal with Images - Chat with a multimodal (image chat) model
|
||||
|
||||
- [multimodal-chat.py](multimodal-chat.py)
|
||||
- [multimodal-generate.py](multimodal-generate.py)
|
||||
|
||||
|
||||
### Structured Outputs - Generate structured outputs with a model
|
||||
|
||||
- [structured-outputs.py](structured-outputs.py)
|
||||
- [async-structured-outputs.py](async-structured-outputs.py)
|
||||
- [structured-outputs-image.py](structured-outputs-image.py)
|
||||
|
||||
|
||||
### Ollama List - List all downloaded models and their properties
|
||||
|
||||
- [list.py](list.py)
|
||||
|
||||
|
||||
### Ollama Show - Display model properties and capabilities
|
||||
|
||||
- [show.py](show.py)
|
||||
|
||||
|
||||
### Ollama ps - Show model status with CPU/GPU usage
|
||||
|
||||
- [ps.py](ps.py)
|
||||
|
||||
|
||||
### Ollama Pull - Pull a model from Ollama
|
||||
Requirement: `pip install tqdm`
|
||||
- [pull.py](pull.py)
|
||||
|
||||
Requirement: `pip install tqdm`
|
||||
|
||||
- [pull.py](pull.py)
|
||||
|
||||
### Ollama Create - Create a model from a Modelfile
|
||||
- [create.py](create.py)
|
||||
|
||||
- [create.py](create.py)
|
||||
|
||||
### Ollama Embed - Generate embeddings with a model
|
||||
|
||||
- [embed.py](embed.py)
|
||||
|
||||
|
||||
### Thinking - Enable thinking mode for a model
|
||||
|
||||
- [thinking.py](thinking.py)
|
||||
|
||||
### Thinking (generate) - Enable thinking mode for a model
|
||||
|
||||
- [thinking-generate.py](thinking-generate.py)
|
||||
|
||||
### Thinking (levels) - Choose the thinking level
|
||||
|
||||
- [thinking-levels.py](thinking-levels.py)
|
||||
|
||||
@@ -0,0 +1,500 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Protocol, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ollama import Client
|
||||
|
||||
|
||||
@dataclass
class Page:
    """One page held by the browser: raw text plus its wrapped, numbered view."""

    # Key under which the page is cached and pushed onto the history stack.
    url: str
    # Human-readable title shown in the page header.
    title: str
    # Full page text after any link processing.
    text: str
    # `text` broken into display lines.
    lines: List[str]
    # Link id -> target URL for links that can be opened from this page.
    links: Dict[int, str]
    # Timestamp recorded when the page object was built.
    fetched_at: datetime
|
||||
|
||||
|
||||
@dataclass
class BrowserStateData:
    """Mutable navigation state: visit history, view budget and page cache."""

    # URLs in the order they were opened; the last entry is the current page.
    page_stack: List[str] = field(default_factory=list)
    # Token budget used to size a view when no explicit line count is given.
    view_tokens: int = 1024
    # Every known page, keyed by its URL.
    url_to_page: Dict[str, Page] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class WebSearchResult:
    """A single search hit in the shape the page builders consume."""

    # Title of the hit.
    title: str
    # Target URL of the hit.
    url: str
    # Named content variants; the builders in this file read the 'fullText' key.
    content: Dict[str, str]
|
||||
|
||||
|
||||
class SearchClient(Protocol):
    """Structural interface for objects that can run web searches."""

    def search(self, queries: List[str], max_results: Optional[int] = None):
        """Run `queries`, optionally limiting how many results come back."""
        ...
|
||||
|
||||
|
||||
class CrawlClient(Protocol):
    """Structural interface for objects that can fetch a batch of URLs."""

    def crawl(self, urls: List[str]):
        """Fetch each URL in `urls`."""
        ...
|
||||
|
||||
|
||||
DEFAULT_VIEW_TOKENS = 1024
CAPPED_TOOL_CONTENT_LEN = 8000


def cap_tool_content(text: str, max_len: Optional[int] = None) -> str:
    """Truncate tool output so it never exceeds a character budget.

    Args:
        text: The tool output to cap. Empty/None-like values are returned as-is.
        max_len: Maximum number of characters to keep. Defaults to
            CAPPED_TOOL_CONTENT_LEN when omitted.

    Returns:
        `text` unchanged when it already fits; otherwise the first
        `max_len - 1` characters followed by a single ellipsis character, so
        the result is exactly `max_len` characters long. When the limit is 1
        or less there is no room for the ellipsis and a plain prefix is
        returned instead.
    """
    limit = CAPPED_TOOL_CONTENT_LEN if max_len is None else max_len
    if not text or len(text) <= limit:
        return text
    if limit <= 1:
        # A non-positive limit must yield '', not a slice counted from the end.
        return text[: max(limit, 0)]
    return text[: limit - 1] + '…'
|
||||
|
||||
|
||||
def _safe_domain(u: str) -> str:
|
||||
try:
|
||||
parsed = urlparse(u)
|
||||
host = parsed.netloc or u
|
||||
return host.replace('www.', '') if host else u
|
||||
except Exception:
|
||||
return u
|
||||
|
||||
|
||||
class BrowserState:
    """Holder for the browser's state object, with explicit get/set access."""

    def __init__(self, initial_state: Optional[BrowserStateData] = None):
        """Adopt `initial_state` when given, otherwise start from a fresh state."""
        if initial_state:
            self._data = initial_state
        else:
            self._data = BrowserStateData(view_tokens=DEFAULT_VIEW_TOKENS)

    def get_data(self) -> BrowserStateData:
        """Return the state object currently held (not a copy)."""
        return self._data

    def set_data(self, data: BrowserStateData) -> None:
        """Replace the held state object with `data`."""
        self._data = data
|
||||
|
||||
|
||||
class Browser:
    """Text browser that keeps a page history and delegates fetching to a client."""

    def __init__(
        self,
        initial_state: Optional[BrowserStateData] = None,
        client: Optional[Client] = None,
    ):
        """Create a browser.

        Args:
            initial_state: Existing state to resume from; a fresh state is
                created when omitted.
            client: Client used for searching and fetching; may be supplied
                later through `set_client`.
        """
        self._client: Optional[Client] = client
        self.state = BrowserState(initial_state)

    def set_client(self, client: Client) -> None:
        """Install (or replace) the client used for network calls."""
        self._client = client

    def get_state(self) -> BrowserStateData:
        """Return the live state object held by this browser."""
        current = self.state.get_data()
        return current
|
||||
|
||||
# ---- internal utils ----
|
||||
|
||||
def _save_page(self, page: Page) -> None:
|
||||
data = self.state.get_data()
|
||||
data.url_to_page[page.url] = page
|
||||
data.page_stack.append(page.url)
|
||||
self.state.set_data(data)
|
||||
|
||||
def _page_from_stack(self, url: str) -> Page:
|
||||
data = self.state.get_data()
|
||||
page = data.url_to_page.get(url)
|
||||
if not page:
|
||||
raise ValueError(f'Page not found for url {url}')
|
||||
return page
|
||||
|
||||
def _join_lines_with_numbers(self, lines: List[str]) -> str:
|
||||
result = []
|
||||
for i, line in enumerate(lines):
|
||||
result.append(f'L{i}: {line}')
|
||||
return '\n'.join(result)
|
||||
|
||||
def _wrap_lines(self, text: str, width: int = 80) -> List[str]:
|
||||
if width <= 0:
|
||||
width = 80
|
||||
src_lines = text.split('\n')
|
||||
wrapped: List[str] = []
|
||||
for line in src_lines:
|
||||
if line == '':
|
||||
wrapped.append('')
|
||||
elif len(line) <= width:
|
||||
wrapped.append(line)
|
||||
else:
|
||||
words = re.split(r'\s+', line)
|
||||
if not words:
|
||||
wrapped.append(line)
|
||||
continue
|
||||
curr = ''
|
||||
for w in words:
|
||||
test = (curr + ' ' + w) if curr else w
|
||||
if len(test) > width and curr:
|
||||
wrapped.append(curr)
|
||||
curr = w
|
||||
else:
|
||||
curr = test
|
||||
if curr:
|
||||
wrapped.append(curr)
|
||||
return wrapped
|
||||
|
||||
def _process_markdown_links(self, text: str) -> Tuple[str, Dict[int, str]]:
    """Rewrite markdown links as numbered citations and collect their targets.

    All runs of whitespace in `text` (newlines included) are collapsed to a
    single space before links are rewritten.

    Returns:
        A tuple of the rewritten text and a mapping of link id -> URL, with
        ids assigned from 0 in order of appearance.
    """
    # Re-join links whose "[label]" and "(target)" ended up on separate lines.
    text = re.sub(
        r'\[([^\]]+)\]\s*\n\s*\(([^)]+)\)',
        lambda m: f'[{m.group(1)}]({m.group(2)})',
        text,
    )
    text = re.sub(r'\s+', ' ', text)

    targets: Dict[int, str] = {}

    def _to_citation(match: re.Match) -> str:
        # The next free id always equals the number of links seen so far.
        ident = len(targets)
        label = match.group(1).strip()
        target = match.group(2).strip()
        citation = f'【{ident}†{label}†{_safe_domain(target)}】'
        targets[ident] = target
        return citation

    rewritten = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', _to_citation, text)
    return rewritten, targets
|
||||
|
||||
def _get_end_loc(self, loc: int, num_lines: int, total_lines: int, lines: List[str]) -> int:
|
||||
if num_lines <= 0:
|
||||
txt = self._join_lines_with_numbers(lines[loc:])
|
||||
data = self.state.get_data()
|
||||
chars_per_token = 4
|
||||
max_chars = min(data.view_tokens * chars_per_token, len(txt))
|
||||
num_lines = txt[:max_chars].count('\n') + 1
|
||||
return min(loc + num_lines, total_lines)
|
||||
|
||||
def _display_page(self, page: Page, cursor: int, loc: int, num_lines: int) -> str:
|
||||
total_lines = len(page.lines) or 0
|
||||
if total_lines == 0:
|
||||
page.lines = ['']
|
||||
total_lines = 1
|
||||
|
||||
if loc != loc or loc < 0:
|
||||
loc = 0
|
||||
elif loc >= total_lines:
|
||||
loc = max(0, total_lines - 1)
|
||||
|
||||
end_loc = self._get_end_loc(loc, num_lines, total_lines, page.lines)
|
||||
|
||||
header = f'[{cursor}] {page.title}'
|
||||
header += f'({page.url})\n' if page.url else '\n'
|
||||
header += f'**viewing lines [{loc} - {end_loc - 1}] of {total_lines - 1}**\n\n'
|
||||
|
||||
body_lines = []
|
||||
for i in range(loc, end_loc):
|
||||
body_lines.append(f'L{i}: {page.lines[i]}')
|
||||
|
||||
return header + '\n'.join(body_lines)
|
||||
|
||||
def _build_search_results_page_collection(self, query: str, results: Dict[str, Any]) -> Page:
    """Build the listing page that summarises every hit for `query`.

    Args:
        query: The search query; used as the page title and in its synthetic URL.
        results: Mapping with a 'results' key whose values are lists of hit
            dicts carrying 'title', 'url' and 'content'.

    Returns:
        A page with one numbered citation line and one cleaned snippet per
        hit, and `links` mapping each citation number to the hit's URL.
    """
    listing = ['', '# Search Results', '']
    targets: Dict[int, str] = {}

    for hits in results.get('results', {}).values():
        for hit in hits:
            # Ids are handed out sequentially across all hit lists.
            ident = len(targets)
            domain = _safe_domain(hit.get('url', ''))
            listing.append(f'* 【{ident}†{hit.get("title", "")}†{domain}】')

            snippet = hit.get('content') or ''
            if len(snippet) > 400:
                snippet = snippet[:400] + '…'
            # Shorten very long digit runs, then squeeze wide whitespace gaps.
            snippet = re.sub(r'\d{40,}', lambda m: m.group(0)[:40] + '…', snippet)
            snippet = re.sub(r'\s{3,}', ' ', snippet)
            listing.append(snippet)

            targets[ident] = hit.get('url', '')

    body = '\n'.join(listing)
    return Page(
        url=f'search_results_{query}',
        title=query,
        text=body,
        lines=self._wrap_lines(body, 80),
        links=targets,
        fetched_at=datetime.utcnow(),
    )
|
||||
|
||||
def _build_search_result_page(self, result: WebSearchResult, link_idx: int) -> Page:
    """Build the page for a single search hit.

    When the hit carries 'fullText' content, that text (prefixed with the
    URL) is link-processed and becomes the page. Otherwise the page is a
    short preview whose only link, numbered `link_idx`, is the hit's URL.
    """
    body = result.content.get('fullText', '') if result.content else ''

    if body:
        text, links = self._process_markdown_links(f'URL: {result.url}\n{body}')
    else:
        # Preview layout: citation line, URL line, then the (empty) excerpt
        # followed by a blank line.
        text = f'【{link_idx}†{result.title}】\n' + f'URL: {result.url}\n' + body[:300] + '\n\n'
        links = {link_idx: result.url}

    return Page(
        url=result.url,
        title=result.title,
        text=text,
        lines=self._wrap_lines(text, 80),
        links=links,
        fetched_at=datetime.utcnow(),
    )
|
||||
|
||||
def _build_page_from_fetch(self, requested_url: str, fetch_response: Dict[str, Any]) -> Page:
    """Turn a normalized fetch response into a page.

    Args:
        requested_url: URL that was asked for; used as the URL and title
            when the response provides nothing better.
        fetch_response: Mapping with a 'results' key of URL -> list of result
            dicts. Only the first entry of the first non-empty list is used.

    Returns:
        A page whose text is the fetched content prefixed with its URL, or a
        fixed placeholder sentence when no content was returned. The text is
        link-processed and wrapped in either case.
    """
    url, title, text = requested_url, requested_url, ''

    for candidate_url, entries in fetch_response.get('results', {}).items():
        if not entries:
            continue
        first = entries[0]
        text = first.get('content') or text
        title = first.get('title') or title
        url = candidate_url
        break

    if text:
        text = f'URL: {url}\n{text}'
    else:
        text = 'No content could be extracted from this page.'

    processed, links = self._process_markdown_links(text)
    return Page(
        url=url,
        title=title,
        text=processed,
        lines=self._wrap_lines(processed, 80),
        links=links,
        fetched_at=datetime.utcnow(),
    )
|
||||
|
||||
def _build_find_results_page(self, pattern: str, page: Page) -> Page:
    """Build a page listing case-insensitive matches of `pattern` in `page`.

    Each match contributes a numbered citation line followed by up to four
    lines of context starting at the matching line. Lines already shown as
    context are not searched again, and at most 50 matches are reported.
    """
    max_results = 50
    num_show_lines = 4
    needle = pattern.lower()

    chunks: List[str] = []
    resume_at = 0  # first line index that may produce the next match
    for index, line in enumerate(page.lines):
        if index < resume_at or needle not in line.lower():
            continue
        context = '\n'.join(page.lines[index : min(index + num_show_lines, len(page.lines))])
        chunks.append(f'【{len(chunks)}†match at L{index}】\n{context}')
        if len(chunks) >= max_results:
            break
        resume_at = index + num_show_lines

    if chunks:
        body = '\n\n'.join(chunks)
    else:
        body = f'No `find` results for pattern: `{pattern}`'

    return Page(
        url=f'find_results_{pattern}',
        title=f'Find results for text: `{pattern}` in `{page.title}`',
        text=body,
        lines=self._wrap_lines(body, 80),
        links={},
        fetched_at=datetime.utcnow(),
    )
|
||||
|
||||
def search(self, *, query: str, topn: int = 5) -> Dict[str, Any]:
    """Run a web search and open the resulting listing page.

    Args:
        query: Search query.
        topn: Maximum number of results requested from the client.

    Returns:
        Dict with 'state' (the browser state) and 'pageText' (the rendered,
        length-capped listing page).

    Raises:
        RuntimeError: If no client has been provided.
    """
    if not self._client:
        raise RuntimeError('Client not provided')

    response = self._client.web_search(query, max_results=topn)

    hits: List[Dict[str, str]] = [
        {'title': item.title, 'url': item.url, 'content': item.content or ''}
        for item in response.results
    ]
    normalized: Dict[str, Any] = {'results': {query: hits}}

    listing_page = self._build_search_results_page_collection(query, normalized)
    self._save_page(listing_page)
    cursor = len(self.get_state().page_stack) - 1

    # Pre-cache one page per hit so it can be opened later without a fetch.
    for group in normalized.get('results', {}).values():
        for position, hit in enumerate(group):
            entry = WebSearchResult(
                title=hit.get('title', ''),
                url=hit.get('url', ''),
                content={'fullText': hit.get('content', '') or ''},
            )
            hit_page = self._build_search_result_page(entry, position + 1)
            snapshot = self.get_state()
            snapshot.url_to_page[hit_page.url] = hit_page
            self.state.set_data(snapshot)

    rendered = self._display_page(listing_page, cursor, loc=0, num_lines=-1)
    return {'state': self.get_state(), 'pageText': cap_tool_content(rendered)}
|
||||
|
||||
def open(
    self,
    *,
    id: Optional[str | int] = None,
    cursor: int = -1,
    loc: int = 0,
    num_lines: int = -1,
) -> Dict[str, Any]:
    """Open a page by URL, follow a numbered link, or re-view the current page.

    Args:
        id: A URL string to open directly, an integer link id to follow from
            the page selected by `cursor`, or None to re-display that page.
        cursor: Index into the history stack of the page to act on. Negative
            values select the most recent page; values past the end are
            clamped to the last page.
        loc: First line to display.
        num_lines: Number of lines to display; non-positive lets the view
            size be derived from the token budget.

    Returns:
        Dict with 'state' (the browser state) and 'pageText' (the rendered,
        length-capped page). An unknown link id yields an explanatory page
        rather than an exception.

    Raises:
        RuntimeError: If no client has been provided, or if there is no
            current page to resolve a link from or to display.
    """
    if not self._client:
        raise RuntimeError('Client not provided')

    state = self.get_state()

    if isinstance(id, str):
        if id in state.url_to_page:
            target = state.url_to_page[id]
        else:
            target = self._fetch_page(id)
        return self._push_and_render(target, loc, num_lines)

    # Resolve the page the request refers to (only needed for int id / no id).
    page: Optional[Page] = None
    if state.page_stack:
        index = min(cursor, len(state.page_stack) - 1) if cursor >= 0 else -1
        page = self._page_from_stack(state.page_stack[index])

    if isinstance(id, int):
        if not page:
            raise RuntimeError('No current page to resolve link from')

        link_url = page.links.get(id)
        if not link_url:
            return self._push_and_render(self._build_invalid_link_page(id, page), 0, -1)

        target = state.url_to_page.get(link_url)
        if not target:
            target = self._fetch_page(link_url)
        return self._push_and_render(target, loc, num_lines)

    if not page:
        raise RuntimeError('No current page to display')

    # Re-viewing pushes the page again so the new view gets its own cursor;
    # the page is already cached, so re-saving it leaves the cache unchanged.
    return self._push_and_render(page, loc, num_lines)


def _fetch_page(self, url: str) -> Page:
    """Fetch `url` through the client and build a page from the response."""
    fetched = self._client.web_fetch(url)
    normalized: Dict[str, Any] = {
        'results': {
            url: [
                {
                    'title': fetched.title or url,
                    'url': url,
                    'content': fetched.content or '',
                }
            ]
        }
    }
    return self._build_page_from_fetch(url, normalized)


def _build_invalid_link_page(self, link_id: int, page: Page) -> Page:
    """Build the explanatory page shown when `page` has no link `link_id`."""
    available = sorted(page.links.keys())
    available_list = ', '.join(map(str, available)) if available else '(none)'
    text = '\n'.join(
        [
            f'Requested link id: {link_id}',
            f'Current page: {page.title}',
            f'Available link ids on this page: {available_list}',
            '',
            'Tips:',
            '- To scroll this page, call browser_open with { loc, num_lines } (no id).',
            '- To open a result from a search results page, pass the correct { cursor, id }.',
        ]
    )
    return Page(
        url=f'invalid_link_{link_id}',
        title=f'No link with id {link_id} on `{page.title}`',
        text=text,
        lines=self._wrap_lines(text, 80),
        links={},
        fetched_at=datetime.utcnow(),
    )


def _push_and_render(self, page: Page, loc: int, num_lines: int) -> Dict[str, Any]:
    """Save `page` as the newest history entry and return its rendered view."""
    self._save_page(page)
    cursor = len(self.get_state().page_stack) - 1
    page_text = self._display_page(page, cursor, loc, num_lines)
    return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}
|
||||
|
||||
def find(self, *, pattern: str, cursor: int = -1) -> Dict[str, Any]:
    """Search a page in the history for `pattern` and open the results page.

    Args:
        pattern: Text to look for (matching is done by the results builder).
        cursor: Index into the history stack of the page to search. -1
            selects the most recent page; other out-of-range values are
            clamped into the valid range.

    Returns:
        Dict with 'state' (the browser state) and 'pageText' (the rendered,
        length-capped results page).

    Raises:
        RuntimeError: If the history is empty, whatever `cursor` is.
    """
    state = self.get_state()
    # Check emptiness up front: clamping against an empty stack would
    # otherwise end in an IndexError for any explicit cursor.
    if not state.page_stack:
        raise RuntimeError('No pages to search in')

    last = len(state.page_stack) - 1
    cursor = last if cursor == -1 else max(0, min(cursor, last))
    page = self._page_from_stack(state.page_stack[cursor])

    find_page = self._build_find_results_page(pattern, page)
    self._save_page(find_page)
    new_cursor = len(self.get_state().page_stack) - 1

    page_text = self._display_page(find_page, new_cursor, 0, -1)
    return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}
|
||||
@@ -0,0 +1,116 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.11"
|
||||
# dependencies = [
|
||||
# "mcp",
|
||||
# "rich",
|
||||
# "ollama",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
MCP stdio server exposing Ollama web_search and web_fetch as tools.
|
||||
|
||||
Environment:
|
||||
- OLLAMA_API_KEY (required): API key, sent as the Bearer token in the Authorization header.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
|
||||
from ollama import Client
|
||||
|
||||
try:
|
||||
# Preferred high-level API (if available)
|
||||
from mcp.server.fastmcp import FastMCP # type: ignore
|
||||
|
||||
_FASTMCP_AVAILABLE = True
|
||||
except Exception:
|
||||
_FASTMCP_AVAILABLE = False
|
||||
|
||||
if not _FASTMCP_AVAILABLE:
|
||||
# Fallback to the low-level stdio server API
|
||||
from mcp.server import Server # type: ignore
|
||||
from mcp.server.stdio import stdio_server # type: ignore
|
||||
|
||||
|
||||
client = Client()


def _web_search_impl(query: str, max_results: int = 3) -> Dict[str, Any]:
    """Run a web search through the shared client and return it as a plain dict."""
    response = client.web_search(query=query, max_results=max_results)
    payload = response.model_dump()
    return payload
|
||||
|
||||
|
||||
def _web_fetch_impl(url: str) -> Dict[str, Any]:
    """Fetch `url` through the shared client and return the response as a plain dict."""
    response = client.web_fetch(url=url)
    payload = response.model_dump()
    return payload
|
||||
|
||||
|
||||
# Expose the two helpers as tools on whichever server class was imported above.
if _FASTMCP_AVAILABLE:
    # High-level path: the tool functions are declared through FastMCP decorators.
    app = FastMCP('ollama-search-fetch')

    @app.tool()
    def web_search(query: str, max_results: int = 3) -> Dict[str, Any]:
        """
        Perform a web search using Ollama's hosted search API.

        Args:
            query: The search query to run.
            max_results: Maximum results to return (default: 3).

        Returns:
            JSON-serializable dict matching ollama.WebSearchResponse.model_dump()
        """

        return _web_search_impl(query=query, max_results=max_results)

    @app.tool()
    def web_fetch(url: str) -> Dict[str, Any]:
        """
        Fetch the content of a web page for the provided URL.

        Args:
            url: The absolute URL to fetch.

        Returns:
            JSON-serializable dict matching ollama.WebFetchResponse.model_dump()
        """

        return _web_fetch_impl(url=url)

    if __name__ == '__main__':
        app.run()

else:
    # Fallback path, used only when the FastMCP import failed.
    # NOTE(review): the low-level `mcp.server.Server` may not offer a `.tool()`
    # decorator or a two-argument `run()` (the type-ignore markers below hint
    # at this) — confirm that this branch actually works before relying on it.
    server = Server('ollama-search-fetch')  # type: ignore[name-defined]

    @server.tool()  # type: ignore[attr-defined]
    async def web_search(query: str, max_results: int = 3) -> Dict[str, Any]:
        """
        Perform a web search using Ollama's hosted search API.

        Args:
            query: The search query to run.
            max_results: Maximum results to return (default: 3).
        """

        # The client call is blocking, so it is run in a worker thread.
        return await asyncio.to_thread(_web_search_impl, query, max_results)

    @server.tool()  # type: ignore[attr-defined]
    async def web_fetch(url: str) -> Dict[str, Any]:
        """
        Fetch the content of a web page for the provided URL.

        Args:
            url: The absolute URL to fetch.
        """

        # The client call is blocking, so it is run in a worker thread.
        return await asyncio.to_thread(_web_fetch_impl, url)

    async def _main() -> None:
        # Serve over stdin/stdout until the streams close.
        async with stdio_server() as (read, write):  # type: ignore[name-defined]
            await server.run(read, write)  # type: ignore[attr-defined]

    if __name__ == '__main__':
        asyncio.run(_main())
|
||||
@@ -0,0 +1,85 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.11"
|
||||
# dependencies = [
|
||||
# "rich",
|
||||
# "ollama",
|
||||
# ]
|
||||
# ///
|
||||
from typing import Union
|
||||
|
||||
from rich import print
|
||||
|
||||
from ollama import WebFetchResponse, WebSearchResponse, chat, web_fetch, web_search
|
||||
|
||||
|
||||
def format_tool_results(
    results: Union[WebSearchResponse, WebFetchResponse],
    user_search: str,
):
    """Format a search or fetch response as plain text for a tool message.

    Args:
        results: The response returned by `web_search` or `web_fetch`.
        user_search: The query (for a search) or URL (for a fetch) that
            produced `results`; echoed in the heading.

    Returns:
        The formatted text with trailing whitespace removed, or None when
        `results` is neither of the two supported response types.
    """
    if isinstance(results, WebSearchResponse):
        lines = [f'Search results for "{user_search}":']
        for hit in results.results:
            lines += [
                # Fall back to the content when a hit has no title.
                f'{hit.title}' if hit.title else f'{hit.content}',
                f' URL: {hit.url}',
                f' Content: {hit.content}',
                '',
            ]
        return '\n'.join(lines).rstrip()

    if isinstance(results, WebFetchResponse):
        lines = [
            f'Fetch results for "{user_search}":',
            f'Title: {results.title}',
            f'URL: {user_search}' if user_search else '',
            f'Content: {results.content}',
        ]
        if results.links:
            lines.append(f'Links: {", ".join(results.links)}')
        lines.append('')
        return '\n'.join(lines).rstrip()

    return None
|
||||
|
||||
|
||||
# client = Client(headers={'Authorization': f"Bearer {os.getenv('OLLAMA_API_KEY')}"} if api_key else None)
|
||||
# Tool name -> callable, used to dispatch the tool calls requested by the model.
available_tools = {'web_search': web_search, 'web_fetch': web_fetch}

query = "what is ollama's new engine"
print('Query: ', query)

messages = [{'role': 'user', 'content': query}]

# Keep chatting until the model answers without asking for another tool call.
while True:
    response = chat(model='qwen3', messages=messages, tools=[web_search, web_fetch], think=True)
    reply = response.message

    if reply.thinking:
        print('Thinking: ')
        print(reply.thinking + '\n\n')
    if reply.content:
        print('Content: ')
        print(reply.content + '\n')

    messages.append(reply)

    if not reply.tool_calls:
        # no more tool calls, we can stop the loop
        break

    for tool_call in reply.tool_calls:
        function_to_call = available_tools.get(tool_call.function.name)
        if not function_to_call:
            print(f'Tool {tool_call.function.name} not found')
            messages.append({'role': 'tool', 'content': f'Tool {tool_call.function.name} not found', 'tool_name': tool_call.function.name})
            continue

        args = tool_call.function.arguments
        result: Union[WebSearchResponse, WebFetchResponse] = function_to_call(**args)
        print('Result from tool call name:', tool_call.function.name, 'with arguments:')
        print(args)
        print()

        user_search = args.get('query', '') or args.get('url', '')
        formatted_tool_results = format_tool_results(result, user_search=user_search)

        # Show only a short preview on the console.
        print(formatted_tool_results[:300])
        print()

        # caps the result at ~2000 tokens (assuming roughly 4 characters per token)
        messages.append({'role': 'tool', 'content': formatted_tool_results[: 2000 * 4], 'tool_name': tool_call.function.name})
|
||||
@@ -15,6 +15,8 @@ from ollama._types import (
|
||||
ShowResponse,
|
||||
StatusResponse,
|
||||
Tool,
|
||||
WebFetchResponse,
|
||||
WebSearchResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -35,6 +37,8 @@ __all__ = [
|
||||
'ShowResponse',
|
||||
'StatusResponse',
|
||||
'Tool',
|
||||
'WebFetchResponse',
|
||||
'WebSearchResponse',
|
||||
]
|
||||
|
||||
_client = Client()
|
||||
@@ -51,3 +55,5 @@ list = _client.list
|
||||
copy = _client.copy
|
||||
show = _client.show
|
||||
ps = _client.ps
|
||||
web_search = _client.web_search
|
||||
web_fetch = _client.web_fetch
|
||||
|
||||
-100
@@ -1,100 +0,0 @@
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
|
||||
class OllamaAuth:
    """Signs requests with the user's Ollama Ed25519 SSH key."""

    def __init__(self, key_path: Optional[str] = None):
        """Initialize the OllamaAuth class.

        Args:
            key_path: Optional path to the private key file. If not provided,
                defaults to ~/.ollama/id_ed25519
        """
        if key_path is None:
            home = str(Path.home())
            self.key_path = os.path.join(home, '.ollama', 'id_ed25519')
        else:
            # Expand ~ and environment variables in the path
            self.key_path = os.path.expanduser(os.path.expandvars(key_path))

    def load_private_key(self):
        """Read and load the private key.

        Returns:
            The loaded Ed25519 private key.

        Raises:
            FileNotFoundError: If the key file doesn't exist
            ValueError: If the key file is invalid
        """
        try:
            with open(self.key_path, 'rb') as f:
                private_key_data = f.read()

            private_key = serialization.load_ssh_private_key(
                private_key_data,
                password=None,
            )
            return private_key
        except FileNotFoundError as e:
            # Chain the original error so the failing OS call stays visible.
            raise FileNotFoundError(f"Could not find Ollama private key at {self.key_path}. Please generate one using: ssh-keygen -t ed25519 -f ~/.ollama/id_ed25519 -N ''") from e
        except Exception as e:
            # Any other failure (unreadable file, bad key data) is reported
            # as an invalid key, with the underlying error kept as the cause.
            raise ValueError(f'Invalid private key at {self.key_path}: {e!s}') from e

    def get_public_key_b64(self, private_key):
        """Get the base64 encoded public key.

        Args:
            private_key: The Ed25519 private key

        Returns:
            Base64 encoded public key string

        Raises:
            ValueError: If the OpenSSH encoding lacks the key field.
        """
        # Get the public key in OpenSSH format and extract the second field (base64-encoded key)
        public_key = private_key.public_key()
        openssh_pub = (
            public_key.public_bytes(
                encoding=serialization.Encoding.OpenSSH,
                format=serialization.PublicFormat.OpenSSH,
            )
            .decode('utf-8')
            .strip()
        )
        parts = openssh_pub.split(' ')
        if len(parts) < 2:
            raise ValueError('Malformed OpenSSH public key')
        public_key_b64 = parts[1]
        return public_key_b64

    def sign_request(self, method: str, path: str):
        """Sign an HTTP request.

        Args:
            method: The HTTP method (e.g. 'GET', 'POST')
            path: The request path (e.g. '/api/chat')

        Returns:
            A tuple of (auth_token, timestamp) where auth_token is the
            authorization header value and timestamp is the request timestamp.

        Raises:
            FileNotFoundError: If the key file doesn't exist
            ValueError: If the key file is invalid
        """
        timestamp = str(int(time.time()))
        # The timestamp is signed as a query parameter of the path.
        path_with_ts = f'{path}&ts={timestamp}' if '?' in path else f'{path}?ts={timestamp}'
        challenge = f'{method},{path_with_ts}'

        private_key = self.load_private_key()
        signature = private_key.sign(challenge.encode())

        public_key_b64 = self.get_public_key_b64(private_key)

        auth_token = f'{public_key_b64}:{base64.b64encode(signature).decode("utf-8")}'

        return auth_token, timestamp
|
||||
+114
-50
@@ -25,7 +25,6 @@ from typing import (
|
||||
import anyio
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
|
||||
from ollama._auth import OllamaAuth
|
||||
from ollama._utils import convert_function_to_tool
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
@@ -67,6 +66,10 @@ from ollama._types import (
|
||||
ShowResponse,
|
||||
StatusResponse,
|
||||
Tool,
|
||||
WebFetchRequest,
|
||||
WebFetchResponse,
|
||||
WebSearchRequest,
|
||||
WebSearchResponse,
|
||||
)
|
||||
|
||||
T = TypeVar('T')
|
||||
@@ -81,7 +84,6 @@ class BaseClient:
|
||||
follow_redirects: bool = True,
|
||||
timeout: Any = None,
|
||||
headers: Optional[Mapping[str, str]] = None,
|
||||
auth_key_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -89,48 +91,31 @@ class BaseClient:
|
||||
except for the following:
|
||||
- `follow_redirects`: True
|
||||
- `timeout`: None
|
||||
- `auth_key_path`: Optional path to the ed25519 private key for authentication
|
||||
`kwargs` are passed to the httpx client.
|
||||
"""
|
||||
self._auth = OllamaAuth(auth_key_path)
|
||||
|
||||
headers = {
|
||||
k.lower(): v
|
||||
for k, v in {
|
||||
**(headers or {}),
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
api_key = os.getenv('OLLAMA_API_KEY', None)
|
||||
if not headers.get('authorization') and api_key:
|
||||
headers['authorization'] = f'Bearer {api_key}'
|
||||
|
||||
self._client = client(
|
||||
base_url=_parse_host(host or os.getenv('OLLAMA_HOST')),
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
# Lowercase all headers to ensure override
|
||||
headers={
|
||||
k.lower(): v
|
||||
for k, v in {
|
||||
**(headers or {}),
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
|
||||
}.items()
|
||||
},
|
||||
headers=headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _prepare_request(self, method: str, path: str, **kwargs) -> Dict[str, Any]:
|
||||
if self._auth:
|
||||
url = str(self._client.build_request(method, path).url)
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
full_path = parsed.path
|
||||
if parsed.query:
|
||||
full_path = f'{full_path}?{parsed.query}'
|
||||
|
||||
auth_token, timestamp = self._auth.sign_request(method, full_path)
|
||||
|
||||
if 'headers' not in kwargs:
|
||||
kwargs['headers'] = {}
|
||||
kwargs['headers']['Authorization'] = auth_token
|
||||
|
||||
if '?' in path:
|
||||
path = f'{path}&ts={timestamp}'
|
||||
else:
|
||||
path = f'{path}?ts={timestamp}'
|
||||
|
||||
return {'method': method, 'url': path, **kwargs}
|
||||
|
||||
|
||||
CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
|
||||
|
||||
@@ -179,18 +164,14 @@ class Client(BaseClient):
|
||||
def _request(
|
||||
self,
|
||||
cls: Type[T],
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
*args,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[T, Iterator[T]]:
|
||||
request_params = self._prepare_request(method, path, **kwargs)
|
||||
|
||||
if stream:
|
||||
|
||||
def inner():
|
||||
with self._client.stream(**request_params) as r:
|
||||
with self._client.stream(*args, **kwargs) as r:
|
||||
try:
|
||||
r.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
@@ -205,7 +186,7 @@ class Client(BaseClient):
|
||||
|
||||
return inner()
|
||||
|
||||
return cls(**self._request_raw(**request_params).json())
|
||||
return cls(**self._request_raw(*args, **kwargs).json())
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
@@ -652,6 +633,54 @@ class Client(BaseClient):
|
||||
'/api/ps',
|
||||
)
|
||||
|
||||
def web_search(self, query: str, max_results: int = 3) -> WebSearchResponse:
    """
    Run a web search through the hosted endpoint on ollama.com.

    Args:
        query: The query to search for
        max_results: The maximum number of results to return (default: 3)

    Returns:
        WebSearchResponse with the search results

    Raises:
        ValueError: If the client has no `authorization` header starting with 'Bearer '
    """
    authorization = self._client.headers.get('authorization', '')
    if not authorization.startswith('Bearer '):
        raise ValueError('Authorization header with Bearer token is required for web search')

    # Fields left as None are dropped from the request body.
    payload = WebSearchRequest(query=query, max_results=max_results).model_dump(exclude_none=True)
    return self._request(
        WebSearchResponse,
        'POST',
        'https://ollama.com/api/web_search',
        json=payload,
    )
|
||||
|
||||
def web_fetch(self, url: str) -> WebFetchResponse:
    """
    Retrieve the content of the web page at the given URL through the hosted endpoint on ollama.com.

    Args:
        url: The URL to fetch

    Returns:
        WebFetchResponse with the fetched result

    Raises:
        ValueError: If the client has no `authorization` header starting with 'Bearer '
    """
    authorization = self._client.headers.get('authorization', '')
    if not authorization.startswith('Bearer '):
        raise ValueError('Authorization header with Bearer token is required for web fetch')

    payload = WebFetchRequest(url=url).model_dump(exclude_none=True)
    return self._request(
        WebFetchResponse,
        'POST',
        'https://ollama.com/api/web_fetch',
        json=payload,
    )
|
||||
|
||||
|
||||
class AsyncClient(BaseClient):
|
||||
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
|
||||
@@ -697,19 +726,14 @@ class AsyncClient(BaseClient):
|
||||
async def _request(
|
||||
self,
|
||||
cls: Type[T],
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
*args,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[T, AsyncIterator[T]]:
|
||||
"""Make a request with optional authentication."""
|
||||
request_params = self._prepare_request(method, path, **kwargs)
|
||||
|
||||
if stream:
|
||||
|
||||
async def inner():
|
||||
async with self._client.stream(**request_params) as r:
|
||||
async with self._client.stream(*args, **kwargs) as r:
|
||||
try:
|
||||
r.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
@@ -724,7 +748,47 @@ class AsyncClient(BaseClient):
|
||||
|
||||
return inner()
|
||||
|
||||
return cls(**(await self._request_raw(**request_params)).json())
|
||||
return cls(**(await self._request_raw(*args, **kwargs)).json())
|
||||
|
||||
async def web_search(self, query: str, max_results: int = 3) -> WebSearchResponse:
    """
    Performs a web search

    Args:
        query: The query to search for
        max_results: The maximum number of results to return (default: 3)

    Returns:
        WebSearchResponse with the search results

    Raises:
        ValueError: If the client has no `authorization` header starting with 'Bearer '
    """
    # Same precondition as the synchronous client: fail early with a clear
    # message instead of sending an unauthenticated request to ollama.com.
    if not self._client.headers.get('authorization', '').startswith('Bearer '):
        raise ValueError('Authorization header with Bearer token is required for web search')

    return await self._request(
        WebSearchResponse,
        'POST',
        'https://ollama.com/api/web_search',
        json=WebSearchRequest(
            query=query,
            max_results=max_results,
        ).model_dump(exclude_none=True),
    )
|
||||
|
||||
async def web_fetch(self, url: str) -> WebFetchResponse:
    """
    Fetches the content of a web page for the provided URL.

    Args:
        url: The URL to fetch

    Returns:
        WebFetchResponse with the fetched result

    Raises:
        ValueError: If the client has no `authorization` header starting with 'Bearer '
    """
    # Same precondition as the synchronous client: fail early with a clear
    # message instead of sending an unauthenticated request to ollama.com.
    if not self._client.headers.get('authorization', '').startswith('Bearer '):
        raise ValueError('Authorization header with Bearer token is required for web fetch')

    return await self._request(
        WebFetchResponse,
        'POST',
        'https://ollama.com/api/web_fetch',
        json=WebFetchRequest(
            url=url,
        ).model_dump(exclude_none=True),
    )
|
||||
|
||||
@overload
|
||||
async def generate(
|
||||
|
||||
@@ -541,6 +541,31 @@ class ProcessResponse(SubscriptableBaseModel):
|
||||
models: Sequence[Model]
|
||||
|
||||
|
||||
class WebSearchRequest(SubscriptableBaseModel):
    """Request body sent by the clients' `web_search` methods."""

    # The query to search for.
    query: str
    # Maximum number of results to return; the clients dump this model with
    # exclude_none=True, so a None value is left out of the payload.
    max_results: Optional[int] = None
|
||||
|
||||
|
||||
class WebSearchResult(SubscriptableBaseModel):
    """A single entry of `WebSearchResponse.results`; every field is optional."""

    # NOTE(review): presumably the text content of the result page — confirm against the API docs.
    content: Optional[str] = None
    # NOTE(review): presumably the result page title — confirm against the API docs.
    title: Optional[str] = None
    # NOTE(review): presumably the result page address — confirm against the API docs.
    url: Optional[str] = None
|
||||
|
||||
|
||||
class WebFetchRequest(SubscriptableBaseModel):
    """Request body sent by the clients' `web_fetch` methods."""

    # The URL to fetch.
    url: str
|
||||
|
||||
|
||||
class WebSearchResponse(SubscriptableBaseModel):
    """Response model returned by the clients' `web_search` methods."""

    # The search results, one WebSearchResult per entry (required field).
    results: Sequence[WebSearchResult]
|
||||
|
||||
|
||||
class WebFetchResponse(SubscriptableBaseModel):
    """Response model returned by the clients' `web_fetch` methods; every field is optional."""

    # NOTE(review): presumably the fetched page's title — confirm against the API docs.
    title: Optional[str] = None
    # NOTE(review): presumably the fetched page's text content — confirm against the API docs.
    content: Optional[str] = None
    # NOTE(review): presumably links found on the fetched page — confirm against the API docs.
    links: Optional[Sequence[str]] = None
|
||||
|
||||
|
||||
class RequestError(Exception):
|
||||
"""
|
||||
Common class for request errors.
|
||||
|
||||
@@ -1195,3 +1195,83 @@ async def test_arbitrary_roles_accepted_in_message_request_async(monkeypatch: py
|
||||
client = AsyncClient()
|
||||
|
||||
await client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}])
|
||||
|
||||
|
||||
def test_client_web_search_requires_bearer_auth_header(monkeypatch: pytest.MonkeyPatch):
    """web_search refuses to run when no Bearer authorization header is configured."""
    # Make sure the client cannot pick up an API key from the environment.
    monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
    unauthenticated = Client()
    expected_message = 'Authorization header with Bearer token is required for web search'

    with pytest.raises(ValueError, match=expected_message):
        unauthenticated.web_search('test query')
|
||||
|
||||
|
||||
def test_client_web_fetch_requires_bearer_auth_header(monkeypatch: pytest.MonkeyPatch):
    """web_fetch refuses to run when no Bearer authorization header is configured."""
    # Make sure the client cannot pick up an API key from the environment.
    monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
    unauthenticated = Client()
    expected_message = 'Authorization header with Bearer token is required for web fetch'

    with pytest.raises(ValueError, match=expected_message):
        unauthenticated.web_fetch('https://example.com')
|
||||
|
||||
|
||||
def _mock_request_web_search(self, cls, method, url, json=None, **kwargs):
    """Stand-in for `Client._request` that checks the shape of a web_search call."""
    assert (method, url) == ('POST', 'https://ollama.com/api/web_search')
    # The payload must be present and carry both search parameters.
    assert json is not None
    assert 'query' in json
    assert 'max_results' in json
    return httpxResponse(status_code=200, content='{"results": {}, "success": true}')
|
||||
|
||||
|
||||
def _mock_request_web_fetch(self, cls, method, url, json=None, **kwargs):
    """Stand-in for `Client._request` that checks the shape of a web_fetch call."""
    assert (method, url) == ('POST', 'https://ollama.com/api/web_fetch')
    # The payload must be present and carry the target URL.
    assert json is not None
    assert 'url' in json
    return httpxResponse(status_code=200, content='{"results": {}, "success": true}')
|
||||
|
||||
|
||||
def test_client_web_search_with_env_api_key(monkeypatch: pytest.MonkeyPatch):
    """An API key from the environment is enough for web_search to issue its request."""
    monkeypatch.setenv('OLLAMA_API_KEY', 'test-key')
    # Replace the transport so that only the request shape is checked.
    monkeypatch.setattr(Client, '_request', _mock_request_web_search)

    env_client = Client()
    env_client.web_search('what is ollama?', max_results=2)
|
||||
|
||||
|
||||
def test_client_web_fetch_with_env_api_key(monkeypatch: pytest.MonkeyPatch):
    """An API key from the environment is enough for web_fetch to issue its request."""
    monkeypatch.setenv('OLLAMA_API_KEY', 'test-key')
    # Replace the transport so that only the request shape is checked.
    monkeypatch.setattr(Client, '_request', _mock_request_web_fetch)

    env_client = Client()
    env_client.web_fetch('https://example.com')
|
||||
|
||||
|
||||
def test_client_web_search_with_explicit_bearer_header(monkeypatch: pytest.MonkeyPatch):
    """A Bearer header passed to the constructor lets web_search run without an env key."""
    monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
    # Replace the transport so that only the request shape is checked.
    monkeypatch.setattr(Client, '_request', _mock_request_web_search)

    explicit_headers = {'Authorization': 'Bearer custom-token'}
    header_client = Client(headers=explicit_headers)
    header_client.web_search('what is ollama?', max_results=1)
|
||||
|
||||
|
||||
def test_client_web_fetch_with_explicit_bearer_header(monkeypatch: pytest.MonkeyPatch):
    """A Bearer header passed to the constructor lets web_fetch run without an env key."""
    monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
    # Replace the transport so that only the request shape is checked.
    monkeypatch.setattr(Client, '_request', _mock_request_web_fetch)

    explicit_headers = {'Authorization': 'Bearer custom-token'}
    header_client = Client(headers=explicit_headers)
    header_client.web_fetch('https://example.com')
|
||||
|
||||
|
||||
def test_client_bearer_header_from_env(monkeypatch: pytest.MonkeyPatch):
    """OLLAMA_API_KEY is turned into a Bearer authorization header on the client."""
    monkeypatch.setenv('OLLAMA_API_KEY', 'env-token')

    configured_headers = Client()._client.headers
    assert configured_headers['authorization'] == 'Bearer env-token'
|
||||
|
||||
|
||||
def test_client_explicit_bearer_header_overrides_env(monkeypatch: pytest.MonkeyPatch):
    """An explicit Authorization header takes precedence over OLLAMA_API_KEY."""
    monkeypatch.setenv('OLLAMA_API_KEY', 'env-token')
    # Replace the transport so that only the request shape is checked.
    monkeypatch.setattr(Client, '_request', _mock_request_web_search)

    explicit_headers = {'Authorization': 'Bearer explicit-token'}
    header_client = Client(headers=explicit_headers)

    assert header_client._client.headers['authorization'] == 'Bearer explicit-token'
    header_client.web_search('override check')
|
||||
|
||||
Reference in New Issue
Block a user