Compare commits

...

7 Commits

Author SHA1 Message Date
Jeffrey Morgan dbccf192ac Add image generation support (#616)
test / test (push) Has been cancelled
test / lint (push) Has been cancelled
2026-01-23 00:33:52 -08:00
dependabot[bot] 60e7b2f9ce build(deps): bump actions/checkout from 5 to 6 (#602)
test / test (push) Has been cancelled
test / lint (push) Has been cancelled
Bumps [actions/checkout](https://github.com/actions/checkout) from 5 to 6.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-29 12:03:13 -08:00
Parth Sareen d1d704050b client: expose resource cleanup methods (#444)
test / test (push) Has been cancelled
test / lint (push) Has been cancelled
2025-12-10 17:09:19 -08:00
Eden Chan 115792583e readme: add cloud models usage and examples (#595)
test / test (push) Has been cancelled
test / lint (push) Has been cancelled
2025-11-13 15:03:58 -08:00
Parth Sareen 0008226fda client/types: add logprobs support (#601)
test / test (push) Waiting to run
test / lint (push) Waiting to run
2025-11-12 18:08:42 -08:00
Parth Sareen 9ddd5f0182 examples: fix model web search (#589)
test / test (push) Has been cancelled
test / lint (push) Has been cancelled
2025-09-24 15:53:51 -07:00
Parth Sareen d967f048d9 examples: gpt oss browser tool (#588)
---------

Co-authored-by: nicole pardal <nicolepardall@gmail.com>
2025-09-24 15:40:53 -07:00
16 changed files with 1109 additions and 383 deletions
+1 -1
View File
@@ -13,7 +13,7 @@ jobs:
id-token: write
contents: write
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: actions/setup-python@v6
- uses: astral-sh/setup-uv@v5
with:
+2 -2
View File
@@ -10,7 +10,7 @@ jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: astral-sh/setup-uv@v5
with:
enable-cache: true
@@ -19,7 +19,7 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: actions/setup-python@v6
- uses: astral-sh/setup-uv@v5
with:
+76 -1
View File
@@ -50,6 +50,82 @@ for chunk in stream:
print(chunk['message']['content'], end='', flush=True)
```
## Cloud Models
Run larger models by offloading to Ollamas cloud while keeping your local workflow.
- Supported models: `deepseek-v3.1:671b-cloud`, `gpt-oss:20b-cloud`, `gpt-oss:120b-cloud`, `kimi-k2:1t-cloud`, `qwen3-coder:480b-cloud`, `kimi-k2-thinking` See [Ollama Models - Cloud](https://ollama.com/search?c=cloud) for more information
### Run via local Ollama
1) Sign in (one-time):
```
ollama signin
```
2) Pull a cloud model:
```
ollama pull gpt-oss:120b-cloud
```
3) Make a request:
```python
from ollama import Client
client = Client()
messages = [
{
'role': 'user',
'content': 'Why is the sky blue?',
},
]
for part in client.chat('gpt-oss:120b-cloud', messages=messages, stream=True):
print(part.message.content, end='', flush=True)
```
### Cloud API (ollama.com)
Access cloud models directly by pointing the client at `https://ollama.com`.
1) Create an API key from [ollama.com](https://ollama.com/settings/keys) , then set:
```
export OLLAMA_API_KEY=your_api_key
```
2) (Optional) List models available via the API:
```
curl https://ollama.com/api/tags
```
3) Generate a response via the cloud API:
```python
import os
from ollama import Client
client = Client(
host='https://ollama.com',
headers={'Authorization': 'Bearer ' + os.environ.get('OLLAMA_API_KEY')}
)
messages = [
{
'role': 'user',
'content': 'Why is the sky blue?',
},
]
for part in client.chat('gpt-oss:120b', messages=messages, stream=True):
print(part.message.content, end='', flush=True)
```
## Custom client
A custom client can be created by instantiating `Client` or `AsyncClient` from `ollama`.
@@ -174,7 +250,6 @@ ollama.embed(model='gemma3', input=['The sky is blue because of rayleigh scatter
ollama.ps()
```
## Errors
Errors are raised if requests return an error status or if an error is detected while streaming.
+7 -2
View File
@@ -36,8 +36,6 @@ See [ollama/docs/api.md](https://github.com/ollama/ollama/blob/main/docs/api.md)
- [gpt-oss-tools.py](gpt-oss-tools.py)
- [gpt-oss-tools-stream.py](gpt-oss-tools-stream.py)
- [gpt-oss-tools-browser.py](gpt-oss-tools-browser.py) - Using browser research tools with gpt-oss
- [gpt-oss-tools-browser-stream.py](gpt-oss-tools-browser-stream.py) - Using browser research tools with gpt-oss, with streaming enabled
### Web search
@@ -48,6 +46,7 @@ export OLLAMA_API_KEY="your_api_key_here"
```
- [web-search.py](web-search.py)
- [web-search-gpt-oss.py](web-search-gpt-oss.py) - Using browser research tools with gpt-oss
#### MCP server
@@ -79,6 +78,12 @@ Configuration to use with an MCP client:
- [multimodal-chat.py](multimodal-chat.py)
- [multimodal-generate.py](multimodal-generate.py)
### Image Generation (Experimental) - Generate images with a model
> **Note:** Image generation is experimental and currently only available on macOS.
- [generate-image.py](generate-image.py)
### Structured Outputs - Generate structured outputs with a model
- [structured-outputs.py](structured-outputs.py)
+31
View File
@@ -0,0 +1,31 @@
from typing import Iterable
import ollama
def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
print(f'\n{label}:')
for entry in logprobs:
token = entry.get('token', '')
logprob = entry.get('logprob')
print(f' token={token!r:<12} logprob={logprob:.3f}')
for alt in entry.get('top_logprobs', []):
if alt['token'] != token:
print(f' alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')
messages = [
{
'role': 'user',
'content': 'hi! be concise.',
},
]
response = ollama.chat(
model='gemma3',
messages=messages,
logprobs=True,
top_logprobs=3,
)
print('Chat response:', response['message']['content'])
print_logprobs(response.get('logprobs', []), 'chat logprobs')
+2 -1
View File
@@ -15,7 +15,8 @@ messages = [
},
{
'role': 'assistant',
'content': 'The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.',
'content': """The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures
rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.""",
},
]
+18
View File
@@ -0,0 +1,18 @@
# Image generation is experimental and currently only available on macOS
import base64
from ollama import generate
prompt = 'a sunset over mountains'
print(f'Prompt: {prompt}')
for response in generate(model='x/z-image-turbo', prompt=prompt, stream=True):
if response.image:
# Final response contains the image
with open('output.png', 'wb') as f:
f.write(base64.b64decode(response.image))
print('\nImage saved to output.png')
elif response.total:
# Progress update
print(f'Progress: {response.completed or 0}/{response.total}', end='\r')
+24
View File
@@ -0,0 +1,24 @@
from typing import Iterable
import ollama
def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
print(f'\n{label}:')
for entry in logprobs:
token = entry.get('token', '')
logprob = entry.get('logprob')
print(f' token={token!r:<12} logprob={logprob:.3f}')
for alt in entry.get('top_logprobs', []):
if alt['token'] != token:
print(f' alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')
response = ollama.generate(
model='gemma3',
prompt='hi! be concise.',
logprobs=True,
top_logprobs=3,
)
print('Generate response:', response['response'])
print_logprobs(response.get('logprobs', []), 'generate logprobs')
-198
View File
@@ -1,198 +0,0 @@
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "gpt-oss",
# "ollama",
# "rich",
# ]
# ///
import asyncio
import json
from typing import Iterator, Optional
from gpt_oss.tools.simple_browser import ExaBackend, SimpleBrowserTool
from openai_harmony import Author, Role, TextContent
from openai_harmony import Message as HarmonyMessage
from rich import print
from ollama import Client
from ollama._types import ChatResponse
_backend = ExaBackend(source='web')
_browser_tool = SimpleBrowserTool(backend=_backend)
def heading(text):
print(text)
print('=' * (len(text) + 3))
async def _browser_search_async(query: str, topn: int = 10, source: str | None = None) -> str:
# map Ollama message to Harmony format
harmony_message = HarmonyMessage(
author=Author(role=Role.USER),
content=[TextContent(text=json.dumps({'query': query, 'topn': topn}))],
recipient='browser.search',
)
result_text: str = ''
async for response in _browser_tool._process(harmony_message):
if response.content:
for content in response.content:
if isinstance(content, TextContent):
result_text += content.text
return result_text or f'No results for query: {query}'
async def _browser_open_async(id: int | str = -1, cursor: int = -1, loc: int = -1, num_lines: int = -1, *, view_source: bool = False, source: str | None = None) -> str:
payload = {'id': id, 'cursor': cursor, 'loc': loc, 'num_lines': num_lines, 'view_source': view_source, 'source': source}
harmony_message = HarmonyMessage(
author=Author(role=Role.USER),
content=[TextContent(text=json.dumps(payload))],
recipient='browser.open',
)
result_text: str = ''
async for response in _browser_tool._process(harmony_message):
if response.content:
for content in response.content:
if isinstance(content, TextContent):
result_text += content.text
return result_text or f'Could not open: {id}'
async def _browser_find_async(pattern: str, cursor: int = -1) -> str:
payload = {'pattern': pattern, 'cursor': cursor}
harmony_message = HarmonyMessage(
author=Author(role=Role.USER),
content=[TextContent(text=json.dumps(payload))],
recipient='browser.find',
)
result_text: str = ''
async for response in _browser_tool._process(harmony_message):
if response.content:
for content in response.content:
if isinstance(content, TextContent):
result_text += content.text
return result_text or f'Pattern not found: {pattern}'
def browser_search(query: str, topn: int = 10, source: Optional[str] = None) -> str:
return asyncio.run(_browser_search_async(query=query, topn=topn, source=source))
def browser_open(id: int | str = -1, cursor: int = -1, loc: int = -1, num_lines: int = -1, *, view_source: bool = False, source: Optional[str] = None) -> str:
return asyncio.run(_browser_open_async(id=id, cursor=cursor, loc=loc, num_lines=num_lines, view_source=view_source, source=source))
def browser_find(pattern: str, cursor: int = -1) -> str:
return asyncio.run(_browser_find_async(pattern=pattern, cursor=cursor))
# Schema definitions for each browser tool
browser_search_schema = {
'type': 'function',
'function': {
'name': 'browser.search',
},
}
browser_open_schema = {
'type': 'function',
'function': {
'name': 'browser.open',
},
}
browser_find_schema = {
'type': 'function',
'function': {
'name': 'browser.find',
},
}
available_tools = {
'browser.search': browser_search,
'browser.open': browser_open,
'browser.find': browser_find,
}
model = 'gpt-oss:20b'
print('Model: ', model, '\n')
prompt = 'What is Ollama?'
print('You: ', prompt, '\n')
messages = [{'role': 'user', 'content': prompt}]
client = Client()
# gpt-oss can call tools while "thinking"
# a loop is needed to call the tools and get the results
final = True
while True:
response_stream: Iterator[ChatResponse] = client.chat(
model=model,
messages=messages,
tools=[browser_search_schema, browser_open_schema, browser_find_schema],
options={'num_ctx': 8192}, # 8192 is the recommended lower limit for the context window
stream=True,
)
tool_calls = []
thinking = ''
content = ''
for chunk in response_stream:
if chunk.message.tool_calls:
tool_calls.extend(chunk.message.tool_calls)
if chunk.message.content:
if not (chunk.message.thinking or chunk.message.thinking == '') and final:
heading('\n\nFinal result: ')
final = False
print(chunk.message.content, end='', flush=True)
if chunk.message.thinking:
thinking += chunk.message.thinking
print(chunk.message.thinking, end='', flush=True)
if thinking != '':
messages.append({'role': 'assistant', 'content': thinking, 'tool_calls': tool_calls})
print()
if tool_calls:
for tool_call in tool_calls:
tool_name = tool_call.function.name
args = tool_call.function.arguments or {}
function_to_call = available_tools.get(tool_name)
if function_to_call:
heading(f'\nCalling tool: {tool_name}')
if args:
print(f'Arguments: {args}')
try:
result = function_to_call(**args)
print(f'Tool result: {result[:200]}')
if len(result) > 200:
heading('... [truncated]')
print()
result_message = {'role': 'tool', 'content': result, 'tool_name': tool_name}
messages.append(result_message)
except Exception as e:
err = f'Error from {tool_name}: {e}'
print(err)
messages.append({'role': 'tool', 'content': err, 'tool_name': tool_name})
else:
print(f'Tool {tool_name} not found')
else:
# no more tool calls, we can stop the loop
break
-175
View File
@@ -1,175 +0,0 @@
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "gpt-oss",
# "ollama",
# "rich",
# ]
# ///
import asyncio
import json
from typing import Optional
from gpt_oss.tools.simple_browser import ExaBackend, SimpleBrowserTool
from openai_harmony import Author, Role, TextContent
from openai_harmony import Message as HarmonyMessage
from ollama import Client
_backend = ExaBackend(source='web')
_browser_tool = SimpleBrowserTool(backend=_backend)
def heading(text):
print(text)
print('=' * (len(text) + 3))
async def _browser_search_async(query: str, topn: int = 10, source: str | None = None) -> str:
# map Ollama message to Harmony format
harmony_message = HarmonyMessage(
author=Author(role=Role.USER),
content=[TextContent(text=json.dumps({'query': query, 'topn': topn}))],
recipient='browser.search',
)
result_text: str = ''
async for response in _browser_tool._process(harmony_message):
if response.content:
for content in response.content:
if isinstance(content, TextContent):
result_text += content.text
return result_text or f'No results for query: {query}'
async def _browser_open_async(id: int | str = -1, cursor: int = -1, loc: int = -1, num_lines: int = -1, *, view_source: bool = False, source: str | None = None) -> str:
payload = {'id': id, 'cursor': cursor, 'loc': loc, 'num_lines': num_lines, 'view_source': view_source, 'source': source}
harmony_message = HarmonyMessage(
author=Author(role=Role.USER),
content=[TextContent(text=json.dumps(payload))],
recipient='browser.open',
)
result_text: str = ''
async for response in _browser_tool._process(harmony_message):
if response.content:
for content in response.content:
if isinstance(content, TextContent):
result_text += content.text
return result_text or f'Could not open: {id}'
async def _browser_find_async(pattern: str, cursor: int = -1) -> str:
payload = {'pattern': pattern, 'cursor': cursor}
harmony_message = HarmonyMessage(
author=Author(role=Role.USER),
content=[TextContent(text=json.dumps(payload))],
recipient='browser.find',
)
result_text: str = ''
async for response in _browser_tool._process(harmony_message):
if response.content:
for content in response.content:
if isinstance(content, TextContent):
result_text += content.text
return result_text or f'Pattern not found: {pattern}'
def browser_search(query: str, topn: int = 10, source: Optional[str] = None) -> str:
return asyncio.run(_browser_search_async(query=query, topn=topn, source=source))
def browser_open(id: int | str = -1, cursor: int = -1, loc: int = -1, num_lines: int = -1, *, view_source: bool = False, source: Optional[str] = None) -> str:
return asyncio.run(_browser_open_async(id=id, cursor=cursor, loc=loc, num_lines=num_lines, view_source=view_source, source=source))
def browser_find(pattern: str, cursor: int = -1) -> str:
return asyncio.run(_browser_find_async(pattern=pattern, cursor=cursor))
# Schema definitions for each browser tool
browser_search_schema = {
'type': 'function',
'function': {
'name': 'browser.search',
},
}
browser_open_schema = {
'type': 'function',
'function': {
'name': 'browser.open',
},
}
browser_find_schema = {
'type': 'function',
'function': {
'name': 'browser.find',
},
}
available_tools = {
'browser.search': browser_search,
'browser.open': browser_open,
'browser.find': browser_find,
}
model = 'gpt-oss:20b'
print('Model: ', model, '\n')
prompt = 'What is Ollama?'
print('You: ', prompt, '\n')
messages = [{'role': 'user', 'content': prompt}]
client = Client()
while True:
response = client.chat(
model=model,
messages=messages,
tools=[browser_search_schema, browser_open_schema, browser_find_schema],
options={'num_ctx': 8192}, # 8192 is the recommended lower limit for the context window
)
if hasattr(response.message, 'thinking') and response.message.thinking:
heading('Thinking')
print(response.message.thinking.strip() + '\n')
if hasattr(response.message, 'content') and response.message.content:
heading('Assistant')
print(response.message.content.strip() + '\n')
# add message to chat history
messages.append(response.message)
if response.message.tool_calls:
for tool_call in response.message.tool_calls:
tool_name = tool_call.function.name
args = tool_call.function.arguments or {}
function_to_call = available_tools.get(tool_name)
if not function_to_call:
print(f'Unknown tool: {tool_name}')
continue
try:
result = function_to_call(**args)
heading(f'Tool: {tool_name}')
if args:
print(f'Arguments: {args}')
print(result[:200])
if len(result) > 200:
print('... [truncated]')
print()
messages.append({'role': 'tool', 'content': result, 'tool_name': tool_name})
except Exception as e:
err = f'Error from {tool_name}: {e}'
print(err)
messages.append({'role': 'tool', 'content': err, 'tool_name': tool_name})
else:
# break on no more tool calls
break
+99
View File
@@ -0,0 +1,99 @@
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "ollama",
# ]
# ///
from typing import Any, Dict, List
from web_search_gpt_oss_helper import Browser
from ollama import Client
def main() -> None:
client = Client()
browser = Browser(initial_state=None, client=client)
def browser_search(query: str, topn: int = 10) -> str:
return browser.search(query=query, topn=topn)['pageText']
def browser_open(id: int | str | None = None, cursor: int = -1, loc: int = -1, num_lines: int = -1) -> str:
return browser.open(id=id, cursor=cursor, loc=loc, num_lines=num_lines)['pageText']
def browser_find(pattern: str, cursor: int = -1, **_: Any) -> str:
return browser.find(pattern=pattern, cursor=cursor)['pageText']
browser_search_schema = {
'type': 'function',
'function': {
'name': 'browser.search',
},
}
browser_open_schema = {
'type': 'function',
'function': {
'name': 'browser.open',
},
}
browser_find_schema = {
'type': 'function',
'function': {
'name': 'browser.find',
},
}
available_tools = {
'browser.search': browser_search,
'browser.open': browser_open,
'browser.find': browser_find,
}
query = "what is ollama's new engine"
print('Prompt:', query, '\n')
messages: List[Dict[str, Any]] = [{'role': 'user', 'content': query}]
while True:
resp = client.chat(
model='gpt-oss:120b-cloud',
messages=messages,
tools=[browser_search_schema, browser_open_schema, browser_find_schema],
think=True,
)
if resp.message.thinking:
print('Thinking:\n========\n')
print(resp.message.thinking + '\n')
if resp.message.content:
print('Response:\n========\n')
print(resp.message.content + '\n')
messages.append(resp.message)
if not resp.message.tool_calls:
break
for tc in resp.message.tool_calls:
tool_name = tc.function.name
args = tc.function.arguments or {}
print(f'Tool name: {tool_name}, args: {args}')
fn = available_tools.get(tool_name)
if not fn:
messages.append({'role': 'tool', 'content': f'Tool {tool_name} not found', 'tool_name': tool_name})
continue
try:
result_text = fn(**args)
print('Result: ', result_text[:200] + '...')
except Exception as e:
result_text = f'Error from {tool_name}: {e}'
messages.append({'role': 'tool', 'content': result_text, 'tool_name': tool_name})
if __name__ == '__main__':
main()
+514
View File
@@ -0,0 +1,514 @@
from __future__ import annotations
import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol, Tuple
from urllib.parse import urlparse
from ollama import Client
@dataclass
class Page:
url: str
title: str
text: str
lines: List[str]
links: Dict[int, str]
fetched_at: datetime
@dataclass
class BrowserStateData:
page_stack: List[str] = field(default_factory=list)
view_tokens: int = 1024
url_to_page: Dict[str, Page] = field(default_factory=dict)
@dataclass
class WebSearchResult:
title: str
url: str
content: Dict[str, str]
class SearchClient(Protocol):
def search(self, queries: List[str], max_results: Optional[int] = None): ...
class CrawlClient(Protocol):
def crawl(self, urls: List[str]): ...
# ---- Constants ---------------------------------------------------------------
DEFAULT_VIEW_TOKENS = 1024
CAPPED_TOOL_CONTENT_LEN = 8000
# ---- Helpers ----------------------------------------------------------------
def cap_tool_content(text: str) -> str:
if not text:
return text
if len(text) <= CAPPED_TOOL_CONTENT_LEN:
return text
if CAPPED_TOOL_CONTENT_LEN <= 1:
return text[:CAPPED_TOOL_CONTENT_LEN]
return text[: CAPPED_TOOL_CONTENT_LEN - 1] + ''
def _safe_domain(u: str) -> str:
try:
parsed = urlparse(u)
host = parsed.netloc or u
return host.replace('www.', '') if host else u
except Exception:
return u
# ---- BrowserState ------------------------------------------------------------
class BrowserState:
def __init__(self, initial_state: Optional[BrowserStateData] = None):
self._data = initial_state or BrowserStateData(view_tokens=DEFAULT_VIEW_TOKENS)
def get_data(self) -> BrowserStateData:
return self._data
def set_data(self, data: BrowserStateData) -> None:
self._data = data
# ---- Browser ----------------------------------------------------------------
class Browser:
def __init__(
self,
initial_state: Optional[BrowserStateData] = None,
client: Optional[Client] = None,
):
self.state = BrowserState(initial_state)
self._client: Optional[Client] = client
def set_client(self, client: Client) -> None:
self._client = client
def get_state(self) -> BrowserStateData:
return self.state.get_data()
# ---- internal utils ----
def _save_page(self, page: Page) -> None:
data = self.state.get_data()
data.url_to_page[page.url] = page
data.page_stack.append(page.url)
self.state.set_data(data)
def _page_from_stack(self, url: str) -> Page:
data = self.state.get_data()
page = data.url_to_page.get(url)
if not page:
raise ValueError(f'Page not found for url {url}')
return page
def _join_lines_with_numbers(self, lines: List[str]) -> str:
result = []
for i, line in enumerate(lines):
result.append(f'L{i}: {line}')
return '\n'.join(result)
def _wrap_lines(self, text: str, width: int = 80) -> List[str]:
if width <= 0:
width = 80
src_lines = text.split('\n')
wrapped: List[str] = []
for line in src_lines:
if line == '':
wrapped.append('')
elif len(line) <= width:
wrapped.append(line)
else:
words = re.split(r'\s+', line)
if not words:
wrapped.append(line)
continue
curr = ''
for w in words:
test = (curr + ' ' + w) if curr else w
if len(test) > width and curr:
wrapped.append(curr)
curr = w
else:
curr = test
if curr:
wrapped.append(curr)
return wrapped
def _process_markdown_links(self, text: str) -> Tuple[str, Dict[int, str]]:
links: Dict[int, str] = {}
link_id = 0
multiline_pattern = re.compile(r'\[([^\]]+)\]\s*\n\s*\(([^)]+)\)')
text = multiline_pattern.sub(lambda m: f'[{m.group(1)}]({m.group(2)})', text)
text = re.sub(r'\s+', ' ', text)
link_pattern = re.compile(r'\[([^\]]+)\]\(([^)]+)\)')
def _repl(m: re.Match) -> str:
nonlocal link_id
link_text = m.group(1).strip()
link_url = m.group(2).strip()
domain = _safe_domain(link_url)
formatted = f'{link_id}{link_text}{domain}'
links[link_id] = link_url
link_id += 1
return formatted
processed = link_pattern.sub(_repl, text)
return processed, links
def _get_end_loc(self, loc: int, num_lines: int, total_lines: int, lines: List[str]) -> int:
if num_lines <= 0:
txt = self._join_lines_with_numbers(lines[loc:])
data = self.state.get_data()
chars_per_token = 4
max_chars = min(data.view_tokens * chars_per_token, len(txt))
num_lines = txt[:max_chars].count('\n') + 1
return min(loc + num_lines, total_lines)
def _display_page(self, page: Page, cursor: int, loc: int, num_lines: int) -> str:
total_lines = len(page.lines) or 0
if total_lines == 0:
page.lines = ['']
total_lines = 1
if loc != loc or loc < 0:
loc = 0
elif loc >= total_lines:
loc = max(0, total_lines - 1)
end_loc = self._get_end_loc(loc, num_lines, total_lines, page.lines)
header = f'[{cursor}] {page.title}'
header += f'({page.url})\n' if page.url else '\n'
header += f'**viewing lines [{loc} - {end_loc - 1}] of {total_lines - 1}**\n\n'
body_lines = []
for i in range(loc, end_loc):
body_lines.append(f'L{i}: {page.lines[i]}')
return header + '\n'.join(body_lines)
# ---- page builders ----
def _build_search_results_page_collection(self, query: str, results: Dict[str, Any]) -> Page:
page = Page(
url=f'search_results_{query}',
title=query,
text='',
lines=[],
links={},
fetched_at=datetime.utcnow(),
)
tb = []
tb.append('')
tb.append('# Search Results')
tb.append('')
link_idx = 0
for query_results in results.get('results', {}).values():
for result in query_results:
domain = _safe_domain(result.get('url', ''))
link_fmt = f'* 【{link_idx}{result.get("title", "")}{domain}'
tb.append(link_fmt)
raw_snip = result.get('content') or ''
capped = (raw_snip[:400] + '') if len(raw_snip) > 400 else raw_snip
cleaned = re.sub(r'\d{40,}', lambda m: m.group(0)[:40] + '', capped)
cleaned = re.sub(r'\s{3,}', ' ', cleaned)
tb.append(cleaned)
page.links[link_idx] = result.get('url', '')
link_idx += 1
page.text = '\n'.join(tb)
page.lines = self._wrap_lines(page.text, 80)
return page
def _build_search_result_page(self, result: WebSearchResult, link_idx: int) -> Page:
page = Page(
url=result.url,
title=result.title,
text='',
lines=[],
links={},
fetched_at=datetime.utcnow(),
)
link_fmt = f'{link_idx}{result.title}\n'
preview = link_fmt + f'URL: {result.url}\n'
full_text = result.content.get('fullText', '') if result.content else ''
preview += full_text[:300] + '\n\n'
if not full_text:
page.links[link_idx] = result.url
if full_text:
raw = f'URL: {result.url}\n{full_text}'
processed, links = self._process_markdown_links(raw)
page.text = processed
page.links = links
else:
page.text = preview
page.lines = self._wrap_lines(page.text, 80)
return page
def _build_page_from_fetch(self, requested_url: str, fetch_response: Dict[str, Any]) -> Page:
page = Page(
url=requested_url,
title=requested_url,
text='',
lines=[],
links={},
fetched_at=datetime.utcnow(),
)
for url, url_results in fetch_response.get('results', {}).items():
if url_results:
r0 = url_results[0]
if r0.get('content'):
page.text = r0['content']
if r0.get('title'):
page.title = r0['title']
page.url = url
break
if not page.text:
page.text = 'No content could be extracted from this page.'
else:
page.text = f'URL: {page.url}\n{page.text}'
processed, links = self._process_markdown_links(page.text)
page.text = processed
page.links = links
page.lines = self._wrap_lines(page.text, 80)
return page
def _build_find_results_page(self, pattern: str, page: Page) -> Page:
find_page = Page(
url=f'find_results_{pattern}',
title=f'Find results for text: `{pattern}` in `{page.title}`',
text='',
lines=[],
links={},
fetched_at=datetime.utcnow(),
)
max_results = 50
num_show_lines = 4
pattern_lower = pattern.lower()
result_chunks: List[str] = []
line_idx = 0
while line_idx < len(page.lines):
line = page.lines[line_idx]
if pattern_lower not in line.lower():
line_idx += 1
continue
end_line = min(line_idx + num_show_lines, len(page.lines))
snippet = '\n'.join(page.lines[line_idx:end_line])
link_fmt = f'{len(result_chunks)}†match at L{line_idx}'
result_chunks.append(f'{link_fmt}\n{snippet}')
if len(result_chunks) >= max_results:
break
line_idx += num_show_lines
if not result_chunks:
find_page.text = f'No `find` results for pattern: `{pattern}`'
else:
find_page.text = '\n\n'.join(result_chunks)
find_page.lines = self._wrap_lines(find_page.text, 80)
return find_page
# ---- public API: search / open / find ------------------------------------
def search(self, *, query: str, topn: int = 5) -> Dict[str, Any]:
if not self._client:
raise RuntimeError('Client not provided')
resp = self._client.web_search(query, max_results=topn)
normalized: Dict[str, Any] = {'results': {}}
rows: List[Dict[str, str]] = []
for item in resp.results:
content = item.content or ''
rows.append(
{
'title': item.title,
'url': item.url,
'content': content,
}
)
normalized['results'][query] = rows
search_page = self._build_search_results_page_collection(query, normalized)
self._save_page(search_page)
cursor = len(self.get_state().page_stack) - 1
for query_results in normalized.get('results', {}).values():
for i, r in enumerate(query_results):
ws = WebSearchResult(
title=r.get('title', ''),
url=r.get('url', ''),
content={'fullText': r.get('content', '') or ''},
)
result_page = self._build_search_result_page(ws, i + 1)
data = self.get_state()
data.url_to_page[result_page.url] = result_page
self.state.set_data(data)
page_text = self._display_page(search_page, cursor, loc=0, num_lines=-1)
return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}
def open(
self,
*,
id: Optional[str | int] = None,
cursor: int = -1,
loc: int = 0,
num_lines: int = -1,
) -> Dict[str, Any]:
if not self._client:
raise RuntimeError('Client not provided')
state = self.get_state()
if isinstance(id, str):
url = id
if url in state.url_to_page:
self._save_page(state.url_to_page[url])
cursor = len(self.get_state().page_stack) - 1
page_text = self._display_page(state.url_to_page[url], cursor, loc, num_lines)
return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}
fetch_response = self._client.web_fetch(url)
normalized: Dict[str, Any] = {
'results': {
url: [
{
'title': fetch_response.title or url,
'url': url,
'content': fetch_response.content or '',
}
]
}
}
new_page = self._build_page_from_fetch(url, normalized)
self._save_page(new_page)
cursor = len(self.get_state().page_stack) - 1
page_text = self._display_page(new_page, cursor, loc, num_lines)
return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}
# Resolve current page from stack only if needed (int id or no id)
page: Optional[Page] = None
if cursor >= 0:
if state.page_stack:
if cursor >= len(state.page_stack):
cursor = max(0, len(state.page_stack) - 1)
page = self._page_from_stack(state.page_stack[cursor])
else:
page = None
else:
if state.page_stack:
page = self._page_from_stack(state.page_stack[-1])
if isinstance(id, int):
if not page:
raise RuntimeError('No current page to resolve link from')
link_url = page.links.get(id)
if not link_url:
err = Page(
url=f'invalid_link_{id}',
title=f'No link with id {id} on `{page.title}`',
text='',
lines=[],
links={},
fetched_at=datetime.utcnow(),
)
available = sorted(page.links.keys())
available_list = ', '.join(map(str, available)) if available else '(none)'
err.text = '\n'.join(
[
f'Requested link id: {id}',
f'Current page: {page.title}',
f'Available link ids on this page: {available_list}',
'',
'Tips:',
'- To scroll this page, call browser_open with { loc, num_lines } (no id).',
'- To open a result from a search results page, pass the correct { cursor, id }.',
]
)
err.lines = self._wrap_lines(err.text, 80)
self._save_page(err)
cursor = len(self.get_state().page_stack) - 1
page_text = self._display_page(err, cursor, 0, -1)
return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}
new_page = state.url_to_page.get(link_url)
if not new_page:
fetch_response = self._client.web_fetch(link_url)
normalized: Dict[str, Any] = {
'results': {
link_url: [
{
'title': fetch_response.title or link_url,
'url': link_url,
'content': fetch_response.content or '',
}
]
}
}
new_page = self._build_page_from_fetch(link_url, normalized)
self._save_page(new_page)
cursor = len(self.get_state().page_stack) - 1
page_text = self._display_page(new_page, cursor, loc, num_lines)
return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}
if not page:
raise RuntimeError('No current page to display')
cur = self.get_state()
cur.page_stack.append(page.url)
self.state.set_data(cur)
cursor = len(cur.page_stack) - 1
page_text = self._display_page(page, cursor, loc, num_lines)
return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}
def find(self, *, pattern: str, cursor: int = -1) -> Dict[str, Any]:
state = self.get_state()
if cursor == -1:
if not state.page_stack:
raise RuntimeError('No pages to search in')
page = self._page_from_stack(state.page_stack[-1])
cursor = len(state.page_stack) - 1
else:
if cursor < 0 or cursor >= len(state.page_stack):
cursor = max(0, min(cursor, len(state.page_stack) - 1))
page = self._page_from_stack(state.page_stack[cursor])
find_page = self._build_find_results_page(pattern, page)
self._save_page(find_page)
new_cursor = len(self.get_state().page_stack) - 1
page_text = self._display_page(find_page, new_cursor, 0, -1)
return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}
+70 -1
View File
@@ -1,3 +1,4 @@
import contextlib
import ipaddress
import json
import os
@@ -75,7 +76,7 @@ from ollama._types import (
T = TypeVar('T')
class BaseClient:
class BaseClient(contextlib.AbstractContextManager, contextlib.AbstractAsyncContextManager):
def __init__(
self,
client,
@@ -116,6 +117,12 @@ class BaseClient:
**kwargs,
)
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
@@ -124,6 +131,9 @@ class Client(BaseClient):
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
super().__init__(httpx.Client, host, **kwargs)
def close(self):
self._client.close()
def _request_raw(self, *args, **kwargs):
try:
r = self._client.request(*args, **kwargs)
@@ -200,11 +210,16 @@ class Client(BaseClient):
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
think: Optional[bool] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
steps: Optional[int] = None,
) -> GenerateResponse: ...
@overload
@@ -219,11 +234,16 @@ class Client(BaseClient):
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
think: Optional[bool] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
steps: Optional[int] = None,
) -> Iterator[GenerateResponse]: ...
def generate(
@@ -237,11 +257,16 @@ class Client(BaseClient):
context: Optional[Sequence[int]] = None,
stream: bool = False,
think: Optional[bool] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: Optional[bool] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
steps: Optional[int] = None,
) -> Union[GenerateResponse, Iterator[GenerateResponse]]:
"""
Create a response using the requested model.
@@ -266,11 +291,16 @@ class Client(BaseClient):
context=context,
stream=stream,
think=think,
logprobs=logprobs,
top_logprobs=top_logprobs,
raw=raw,
format=format,
images=list(_copy_images(images)) if images else None,
options=options,
keep_alive=keep_alive,
width=width,
height=height,
steps=steps,
).model_dump(exclude_none=True),
stream=stream,
)
@@ -284,6 +314,8 @@ class Client(BaseClient):
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[False] = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -298,6 +330,8 @@ class Client(BaseClient):
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -311,6 +345,8 @@ class Client(BaseClient):
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -358,6 +394,8 @@ class Client(BaseClient):
tools=list(_copy_tools(tools)),
stream=stream,
think=think,
logprobs=logprobs,
top_logprobs=top_logprobs,
format=format,
options=options,
keep_alive=keep_alive,
@@ -686,6 +724,9 @@ class AsyncClient(BaseClient):
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
super().__init__(httpx.AsyncClient, host, **kwargs)
async def close(self):
await self._client.aclose()
async def _request_raw(self, *args, **kwargs):
try:
r = await self._client.request(*args, **kwargs)
@@ -802,11 +843,16 @@ class AsyncClient(BaseClient):
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
steps: Optional[int] = None,
) -> GenerateResponse: ...
@overload
@@ -821,11 +867,16 @@ class AsyncClient(BaseClient):
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
steps: Optional[int] = None,
) -> AsyncIterator[GenerateResponse]: ...
async def generate(
@@ -839,11 +890,16 @@ class AsyncClient(BaseClient):
context: Optional[Sequence[int]] = None,
stream: bool = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: Optional[bool] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
steps: Optional[int] = None,
) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]:
"""
Create a response using the requested model.
@@ -867,11 +923,16 @@ class AsyncClient(BaseClient):
context=context,
stream=stream,
think=think,
logprobs=logprobs,
top_logprobs=top_logprobs,
raw=raw,
format=format,
images=list(_copy_images(images)) if images else None,
options=options,
keep_alive=keep_alive,
width=width,
height=height,
steps=steps,
).model_dump(exclude_none=True),
stream=stream,
)
@@ -885,6 +946,8 @@ class AsyncClient(BaseClient):
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[False] = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -899,6 +962,8 @@ class AsyncClient(BaseClient):
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -912,6 +977,8 @@ class AsyncClient(BaseClient):
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -960,6 +1027,8 @@ class AsyncClient(BaseClient):
tools=list(_copy_tools(tools)),
stream=stream,
think=think,
logprobs=logprobs,
top_logprobs=top_logprobs,
format=format,
options=options,
keep_alive=keep_alive,
+53 -1
View File
@@ -210,6 +210,22 @@ class GenerateRequest(BaseGenerateRequest):
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
'Enable thinking mode (for thinking models).'
logprobs: Optional[bool] = None
'Return log probabilities for generated tokens.'
top_logprobs: Optional[int] = None
'Number of alternative tokens and log probabilities to include per position (0-20).'
# Experimental image generation parameters
width: Optional[int] = None
'Width of the generated image in pixels (for image generation models).'
height: Optional[int] = None
'Height of the generated image in pixels (for image generation models).'
steps: Optional[int] = None
'Number of diffusion steps (for image generation models).'
class BaseGenerateResponse(SubscriptableBaseModel):
model: Optional[str] = None
@@ -243,12 +259,25 @@ class BaseGenerateResponse(SubscriptableBaseModel):
'Duration of evaluating inference in nanoseconds.'
class TokenLogprob(SubscriptableBaseModel):
token: str
'Token text.'
logprob: float
'Log probability for the token.'
class Logprob(TokenLogprob):
top_logprobs: Optional[Sequence[TokenLogprob]] = None
'Most likely tokens and their log probabilities.'
class GenerateResponse(BaseGenerateResponse):
"""
Response returned by generate requests.
"""
response: str
response: Optional[str] = None
'Response content. When streaming, this contains a fragment of the response.'
thinking: Optional[str] = None
@@ -257,6 +286,20 @@ class GenerateResponse(BaseGenerateResponse):
context: Optional[Sequence[int]] = None
'Tokenized history up to the point of the response.'
logprobs: Optional[Sequence[Logprob]] = None
'Log probabilities for generated tokens.'
# Image generation response fields
image: Optional[str] = None
'Base64-encoded generated image data (for image generation models).'
# Streaming progress fields (for image generation)
completed: Optional[int] = None
'Number of completed steps (for image generation streaming).'
total: Optional[int] = None
'Total number of steps (for image generation streaming).'
class Message(SubscriptableBaseModel):
"""
@@ -360,6 +403,12 @@ class ChatRequest(BaseGenerateRequest):
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
'Enable thinking mode (for thinking models).'
logprobs: Optional[bool] = None
'Return log probabilities for generated tokens.'
top_logprobs: Optional[int] = None
'Number of alternative tokens and log probabilities to include per position (0-20).'
class ChatResponse(BaseGenerateResponse):
"""
@@ -369,6 +418,9 @@ class ChatResponse(BaseGenerateResponse):
message: Message
'Response message.'
logprobs: Optional[Sequence[Logprob]] = None
'Log probabilities for generated tokens if requested.'
class EmbedRequest(BaseRequest):
input: Union[str, Sequence[str]]
+1 -1
View File
@@ -37,7 +37,7 @@ dependencies = [ 'ruff>=0.9.1' ]
config-path = 'none'
[tool.ruff]
line-length = 999
line-length = 320
indent-width = 2
[tool.ruff.format]
+211
View File
@@ -61,6 +61,44 @@ def test_client_chat(httpserver: HTTPServer):
assert response['message']['content'] == "I don't know."
def test_client_chat_with_logprobs(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Hi'}],
'tools': [],
'stream': False,
'logprobs': True,
'top_logprobs': 3,
},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': 'Hello',
},
'logprobs': [
{
'token': 'Hello',
'logprob': -0.1,
'top_logprobs': [
{'token': 'Hello', 'logprob': -0.1},
{'token': 'Hi', 'logprob': -1.0},
],
}
],
}
)
client = Client(httpserver.url_for('/'))
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Hi'}], logprobs=True, top_logprobs=3)
assert response['logprobs'][0]['token'] == 'Hello'
assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
def test_client_chat_stream(httpserver: HTTPServer):
def stream_handler(_: Request):
def generate():
@@ -294,6 +332,40 @@ def test_client_generate(httpserver: HTTPServer):
assert response['response'] == 'Because it is.'
def test_client_generate_with_logprobs(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'Why',
'stream': False,
'logprobs': True,
'top_logprobs': 2,
},
).respond_with_json(
{
'model': 'dummy',
'response': 'Hello',
'logprobs': [
{
'token': 'Hello',
'logprob': -0.2,
'top_logprobs': [
{'token': 'Hello', 'logprob': -0.2},
{'token': 'Hi', 'logprob': -1.5},
],
}
],
}
)
client = Client(httpserver.url_for('/'))
response = client.generate('dummy', 'Why', logprobs=True, top_logprobs=2)
assert response['logprobs'][0]['token'] == 'Hello'
assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
def test_client_generate_with_image_type(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
@@ -496,6 +568,115 @@ async def test_async_client_generate_format_pydantic(httpserver: HTTPServer):
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
def test_client_generate_image(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy-image',
'prompt': 'a sunset over mountains',
'stream': False,
'width': 1024,
'height': 768,
'steps': 20,
},
).respond_with_json(
{
'model': 'dummy-image',
'image': PNG_BASE64,
'done': True,
'done_reason': 'stop',
}
)
client = Client(httpserver.url_for('/'))
response = client.generate('dummy-image', 'a sunset over mountains', width=1024, height=768, steps=20)
assert response['model'] == 'dummy-image'
assert response['image'] == PNG_BASE64
assert response['done'] is True
def test_client_generate_image_stream(httpserver: HTTPServer):
def stream_handler(_: Request):
def generate():
# Progress updates
for i in range(1, 4):
yield (
json.dumps(
{
'model': 'dummy-image',
'completed': i,
'total': 3,
'done': False,
}
)
+ '\n'
)
# Final response with image
yield (
json.dumps(
{
'model': 'dummy-image',
'image': PNG_BASE64,
'done': True,
'done_reason': 'stop',
}
)
+ '\n'
)
return Response(generate())
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy-image',
'prompt': 'a sunset over mountains',
'stream': True,
'width': 512,
'height': 512,
},
).respond_with_handler(stream_handler)
client = Client(httpserver.url_for('/'))
response = client.generate('dummy-image', 'a sunset over mountains', stream=True, width=512, height=512)
parts = list(response)
# Check progress updates
assert parts[0]['completed'] == 1
assert parts[0]['total'] == 3
assert parts[0]['done'] is False
# Check final response
assert parts[-1]['image'] == PNG_BASE64
assert parts[-1]['done'] is True
async def test_async_client_generate_image(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy-image',
'prompt': 'a robot painting',
'stream': False,
'width': 1024,
'height': 1024,
},
).respond_with_json(
{
'model': 'dummy-image',
'image': PNG_BASE64,
'done': True,
}
)
client = AsyncClient(httpserver.url_for('/'))
response = await client.generate('dummy-image', 'a robot painting', width=1024, height=1024)
assert response['model'] == 'dummy-image'
assert response['image'] == PNG_BASE64
def test_client_pull(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/pull',
@@ -1275,3 +1456,33 @@ def test_client_explicit_bearer_header_overrides_env(monkeypatch: pytest.MonkeyP
client = Client(headers={'Authorization': 'Bearer explicit-token'})
assert client._client.headers['authorization'] == 'Bearer explicit-token'
client.web_search('override check')
def test_client_close():
client = Client()
client.close()
assert client._client.is_closed
@pytest.mark.anyio
async def test_async_client_close():
client = AsyncClient()
await client.close()
assert client._client.is_closed
def test_client_context_manager():
with Client() as client:
assert isinstance(client, Client)
assert not client._client.is_closed
assert client._client.is_closed
@pytest.mark.anyio
async def test_async_client_context_manager():
async with AsyncClient() as client:
assert isinstance(client, AsyncClient)
assert not client._client.is_closed
assert client._client.is_closed