Mirror of https://github.com/ollama/ollama-python.git, synced 2026-01-14 06:07:17 +08:00

Compare commits (21 commits)
Commits:

- 60e7b2f9ce
- d1d704050b
- 115792583e
- 0008226fda
- 9ddd5f0182
- d967f048d9
- ab49a669cd
- 16f344f635
- d0f71bc8b8
- b22c5fdabb
- 4d0b81b37a
- a1d04f04f2
- 8af6cac86b
- 9f41447f20
- da79e987f0
- c8392d6524
- 07ab287cdf
- b0f6b99ca6
- c87604c66f
- 53ff3cd025
- aa4b476f26
.github/workflows/publish.yaml (vendored, 4 changed lines)

@@ -13,8 +13,8 @@ jobs:
       id-token: write
       contents: write
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-python@v5
+      - uses: actions/checkout@v6
+      - uses: actions/setup-python@v6
       - uses: astral-sh/setup-uv@v5
         with:
           enable-cache: true
.github/workflows/test.yaml (vendored, 6 changed lines)

@@ -10,7 +10,7 @@ jobs:
   test:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
       - uses: astral-sh/setup-uv@v5
         with:
           enable-cache: true
@@ -19,8 +19,8 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-python@v5
+      - uses: actions/checkout@v6
+      - uses: actions/setup-python@v6
       - uses: astral-sh/setup-uv@v5
         with:
           enable-cache: true
README.md (+76 lines)

@@ -50,6 +50,82 @@ for chunk in stream:
  print(chunk['message']['content'], end='', flush=True)
```

## Cloud Models

Run larger models by offloading to Ollama's cloud while keeping your local workflow.

- Supported models: `deepseek-v3.1:671b-cloud`, `gpt-oss:20b-cloud`, `gpt-oss:120b-cloud`, `kimi-k2:1t-cloud`, `qwen3-coder:480b-cloud`, `kimi-k2-thinking`. See [Ollama Models - Cloud](https://ollama.com/search?c=cloud) for more information.

### Run via local Ollama

1) Sign in (one-time):

```shell
ollama signin
```

2) Pull a cloud model:

```shell
ollama pull gpt-oss:120b-cloud
```

3) Make a request:

```python
from ollama import Client

client = Client()

messages = [
  {
    'role': 'user',
    'content': 'Why is the sky blue?',
  },
]

for part in client.chat('gpt-oss:120b-cloud', messages=messages, stream=True):
  print(part.message.content, end='', flush=True)
```

### Cloud API (ollama.com)

Access cloud models directly by pointing the client at `https://ollama.com`.

1) Create an API key at [ollama.com](https://ollama.com/settings/keys), then set:

```shell
export OLLAMA_API_KEY=your_api_key
```

2) (Optional) List models available via the API:

```shell
curl https://ollama.com/api/tags
```

3) Generate a response via the cloud API:

```python
import os

from ollama import Client

client = Client(
  host='https://ollama.com',
  headers={'Authorization': 'Bearer ' + os.environ.get('OLLAMA_API_KEY', '')},
)

messages = [
  {
    'role': 'user',
    'content': 'Why is the sky blue?',
  },
]

for part in client.chat('gpt-oss:120b', messages=messages, stream=True):
  print(part.message.content, end='', flush=True)
```

## Custom client

A custom client can be created by instantiating `Client` or `AsyncClient` from `ollama`.
examples/README.md

@@ -1,75 +1,123 @@

# Running Examples

Run the examples in this directory with:

```sh
# Run example
python3 examples/<example>.py

# or with uv
uv run examples/<example>.py
```

See [ollama/docs/api.md](https://github.com/ollama/ollama/blob/main/docs/api.md) for full API documentation.

### Chat - Chat with a model

- [chat.py](chat.py)
- [async-chat.py](async-chat.py)
- [chat-stream.py](chat-stream.py) - Streamed outputs
- [chat-with-history.py](chat-with-history.py) - Chat with a model and maintain the conversation history

### Generate - Generate text with a model

- [generate.py](generate.py)
- [async-generate.py](async-generate.py)
- [generate-stream.py](generate-stream.py) - Streamed outputs
- [fill-in-middle.py](fill-in-middle.py) - Given a prefix and suffix, fill in the middle

### Tools/Function Calling - Call a function with a model

- [tools.py](tools.py) - Simple example of Tools/Function Calling
- [async-tools.py](async-tools.py)
- [multi-tool.py](multi-tool.py) - Using multiple tools, with thinking enabled

-#### gpt-oss
-- [gpt-oss-tools.py](gpt-oss-tools.py) - Using tools with gpt-oss
-- [gpt-oss-tools-stream.py](gpt-oss-tools-stream.py) - Using tools with gpt-oss, with streaming enabled
+#### gpt-oss
+
+- [gpt-oss-tools.py](gpt-oss-tools.py)
+- [gpt-oss-tools-stream.py](gpt-oss-tools-stream.py)

### Web search

An API key from Ollama's cloud service is required. You can create one [here](https://ollama.com/settings/keys).

```shell
export OLLAMA_API_KEY="your_api_key_here"
```

- [web-search.py](web-search.py)
- [web-search-gpt-oss.py](web-search-gpt-oss.py) - Using browser research tools with gpt-oss

#### MCP server

The MCP server can be used with an MCP client such as Cursor, Cline, Codex, Open WebUI, Goose, and more.

```sh
uv run examples/web-search-mcp.py
```

Configuration to use with an MCP client:

```json
{
  "mcpServers": {
    "web_search": {
      "type": "stdio",
      "command": "uv",
      "args": ["run", "path/to/ollama-python/examples/web-search-mcp.py"],
      "env": { "OLLAMA_API_KEY": "your_api_key_here" }
    }
  }
}
```

- [web-search-mcp.py](web-search-mcp.py)

### Multimodal with Images - Chat with a multimodal (image chat) model

- [multimodal-chat.py](multimodal-chat.py)
- [multimodal-generate.py](multimodal-generate.py)

### Structured Outputs - Generate structured outputs with a model (see the sketch below)

- [structured-outputs.py](structured-outputs.py)
- [async-structured-outputs.py](async-structured-outputs.py)
- [structured-outputs-image.py](structured-outputs-image.py)
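For orientation, a minimal structured-outputs sketch (an illustrative addition, not mirrored from the linked examples: it assumes the pydantic `model_json_schema()` pattern, and the model name is an assumption):

```python
# Hedged sketch: assumes the pydantic-based structured-outputs pattern;
# the model name is illustrative, not prescribed by this README.
from pydantic import BaseModel

from ollama import chat


class Country(BaseModel):
  name: str
  capital: str


response = chat(
  model='gemma3',
  messages=[{'role': 'user', 'content': 'Tell me about Canada.'}],
  format=Country.model_json_schema(),  # constrain output to this JSON schema
)
print(Country.model_validate_json(response.message.content))
```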
### Ollama List - List all downloaded models and their properties

- [list.py](list.py)

### Ollama Show - Display model properties and capabilities

- [show.py](show.py)

### Ollama ps - Show model status with CPU/GPU usage

- [ps.py](ps.py)

### Ollama Pull - Pull a model from Ollama

Requirement: `pip install tqdm`

- [pull.py](pull.py)

### Ollama Create - Create a model from a Modelfile

- [create.py](create.py)

### Ollama Embed - Generate embeddings with a model

- [embed.py](embed.py)

### Thinking - Enable thinking mode for a model

- [thinking.py](thinking.py)

### Thinking (generate) - Enable thinking mode for a model

- [thinking-generate.py](thinking-generate.py)

### Thinking (levels) - Choose the thinking level

- [thinking-levels.py](thinking-levels.py)
examples/chat-logprobs.py (new file, 31 lines)

@@ -0,0 +1,31 @@
```python
from typing import Iterable

import ollama


def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
  print(f'\n{label}:')
  for entry in logprobs:
    token = entry.get('token', '')
    logprob = entry.get('logprob')
    print(f'  token={token!r:<12} logprob={logprob:.3f}')
    for alt in entry.get('top_logprobs', []):
      if alt['token'] != token:
        print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')


messages = [
  {
    'role': 'user',
    'content': 'hi! be concise.',
  },
]

response = ollama.chat(
  model='gemma3',
  messages=messages,
  logprobs=True,
  top_logprobs=3,
)
print('Chat response:', response['message']['content'])
print_logprobs(response.get('logprobs', []), 'chat logprobs')
```
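Note that each `logprob` is a natural-log probability, so `math.exp` recovers the model's probability for the token; a tiny hedged helper:

```python
# Hedged helper: logprob values are natural logs; math.exp maps them
# back to probabilities in (0, 1].
import math


def to_probability(logprob: float) -> float:
  return math.exp(logprob)


print(f'{to_probability(-0.105):.2f}')  # ~0.90
```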
examples/chat-with-history.py

@@ -15,7 +15,8 @@ messages = [
   },
   {
     'role': 'assistant',
-    'content': 'The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.',
+    'content': """The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures
+rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.""",
   },
 ]
examples/generate-logprobs.py (new file, 24 lines)

@@ -0,0 +1,24 @@
```python
from typing import Iterable

import ollama


def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
  print(f'\n{label}:')
  for entry in logprobs:
    token = entry.get('token', '')
    logprob = entry.get('logprob')
    print(f'  token={token!r:<12} logprob={logprob:.3f}')
    for alt in entry.get('top_logprobs', []):
      if alt['token'] != token:
        print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')


response = ollama.generate(
  model='gemma3',
  prompt='hi! be concise.',
  logprobs=True,
  top_logprobs=3,
)
print('Generate response:', response['response'])
print_logprobs(response.get('logprobs', []), 'generate logprobs')
```
examples/gpt-oss-tools-stream.py

@@ -1,7 +1,17 @@
+# /// script
+# requires-python = ">=3.11"
+# dependencies = [
+#   "gpt-oss",
+#   "ollama",
+#   "rich",
+# ]
+# ///
 import random
 from typing import Iterator

-from ollama import chat
 from rich import print

+from ollama import Client
+from ollama._types import ChatResponse

@@ -40,37 +50,55 @@ available_tools = {'get_weather': get_weather, 'get_weather_conditions': get_weather_conditions}

messages = [{'role': 'user', 'content': 'What is the weather like in London? What are the conditions in Toronto?'}]

client = Client(
  # Ollama Turbo
  # host="https://ollama.com", headers={'Authorization': (os.getenv('OLLAMA_API_KEY'))}
)

model = 'gpt-oss:20b'
# gpt-oss can call tools while "thinking"
# a loop is needed to call the tools and get the results
final = True
while True:
-  response_stream: Iterator[ChatResponse] = chat(model=model, messages=messages, tools=[get_weather, get_weather_conditions], stream=True)
+  response_stream: Iterator[ChatResponse] = client.chat(model=model, messages=messages, tools=[get_weather, get_weather_conditions], stream=True)
  tool_calls = []
  thinking = ''
  content = ''

  for chunk in response_stream:
    if chunk.message.tool_calls:
      tool_calls.extend(chunk.message.tool_calls)

    if chunk.message.content:
      if not (chunk.message.thinking or chunk.message.thinking == '') and final:
-        print('\nFinal result: ')
+        print('\n\n' + '=' * 10)
+        print('Final result: ')
+        final = False
      print(chunk.message.content, end='', flush=True)

    if chunk.message.thinking:
      # accumulate thinking
      thinking += chunk.message.thinking
      print(chunk.message.thinking, end='', flush=True)

  if thinking != '' or content != '' or len(tool_calls) > 0:
    messages.append({'role': 'assistant', 'thinking': thinking, 'content': content, 'tool_calls': tool_calls})

  print()

-  if chunk.message.tool_calls:
-    for tool_call in chunk.message.tool_calls:
+  if tool_calls:
+    for tool_call in tool_calls:
      function_to_call = available_tools.get(tool_call.function.name)
      if function_to_call:
-        print('\nCalling tool: ', tool_call.function.name, 'with arguments: ', tool_call.function.arguments)
+        print('\nCalling tool:', tool_call.function.name, 'with arguments: ', tool_call.function.arguments)
        result = function_to_call(**tool_call.function.arguments)
        print('Tool result: ', result + '\n')

-        messages.append(chunk.message)
-        messages.append({'role': 'tool', 'content': result, 'tool_name': tool_call.function.name})
+        result_message = {'role': 'tool', 'content': result, 'tool_name': tool_call.function.name}
+        messages.append(result_message)
      else:
        print(f'Tool {tool_call.function.name} not found')
        messages.append({'role': 'tool', 'content': f'Tool {tool_call.function.name} not found', 'tool_name': tool_call.function.name})

  else:
    # no more tool calls, we can stop the loop
examples/gpt-oss-tools.py

@@ -1,6 +1,16 @@
+# /// script
+# requires-python = ">=3.11"
+# dependencies = [
+#   "gpt-oss",
+#   "ollama",
+#   "rich",
+# ]
+# ///
 import random

-from ollama import chat
 from rich import print

+from ollama import Client
+from ollama._types import ChatResponse

@@ -40,11 +50,15 @@ available_tools = {'get_weather': get_weather, 'get_weather_conditions': get_weather_conditions}

messages = [{'role': 'user', 'content': 'What is the weather like in London? What are the conditions in Toronto?'}]


client = Client(
  # Ollama Turbo
  # host="https://ollama.com", headers={'Authorization': (os.getenv('OLLAMA_API_KEY'))}
)
model = 'gpt-oss:20b'
# gpt-oss can call tools while "thinking"
# a loop is needed to call the tools and get the results
while True:
-  response: ChatResponse = chat(model=model, messages=messages, tools=[get_weather, get_weather_conditions])
+  response: ChatResponse = client.chat(model=model, messages=messages, tools=[get_weather, get_weather_conditions])

  if response.message.content:
    print('Content: ')

@@ -53,18 +67,18 @@ while True:
    print('Thinking: ')
    print(response.message.thinking + '\n')

+  messages.append(response.message)

  if response.message.tool_calls:
    for tool_call in response.message.tool_calls:
      function_to_call = available_tools.get(tool_call.function.name)
      if function_to_call:
        result = function_to_call(**tool_call.function.arguments)
        print('Result from tool call name: ', tool_call.function.name, 'with arguments: ', tool_call.function.arguments, 'result: ', result + '\n')

-        messages.append(response.message)
        messages.append({'role': 'tool', 'content': result, 'tool_name': tool_call.function.name})
      else:
        print(f'Tool {tool_call.function.name} not found')
        messages.append({'role': 'tool', 'content': f'Tool {tool_call.function.name} not found', 'tool_name': tool_call.function.name})
  else:
    # no more tool calls, we can stop the loop
    break
examples/thinking-levels.py (new file, 26 lines)

@@ -0,0 +1,26 @@
```python
from ollama import chat


def heading(text):
  print(text)
  print('=' * len(text))


messages = [
  {'role': 'user', 'content': 'What is 10 + 23?'},
]

# gpt-oss supports 'low', 'medium', 'high'
levels = ['low', 'medium', 'high']
for i, level in enumerate(levels):
  response = chat('gpt-oss:20b', messages=messages, think=level)

  heading(f'Thinking ({level})')
  print(response.message.thinking)
  print('\n')
  heading('Response')
  print(response.message.content)
  print('\n')
  if i < len(levels) - 1:
    print('-' * 20)
    print('\n')
```
examples/web-search-gpt-oss.py (new file, 99 lines)

@@ -0,0 +1,99 @@
```python
# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "ollama",
# ]
# ///
from typing import Any, Dict, List

from web_search_gpt_oss_helper import Browser

from ollama import Client


def main() -> None:
  client = Client()
  browser = Browser(initial_state=None, client=client)

  def browser_search(query: str, topn: int = 10) -> str:
    return browser.search(query=query, topn=topn)['pageText']

  def browser_open(id: int | str | None = None, cursor: int = -1, loc: int = -1, num_lines: int = -1) -> str:
    return browser.open(id=id, cursor=cursor, loc=loc, num_lines=num_lines)['pageText']

  def browser_find(pattern: str, cursor: int = -1, **_: Any) -> str:
    return browser.find(pattern=pattern, cursor=cursor)['pageText']

  browser_search_schema = {
    'type': 'function',
    'function': {
      'name': 'browser.search',
    },
  }

  browser_open_schema = {
    'type': 'function',
    'function': {
      'name': 'browser.open',
    },
  }

  browser_find_schema = {
    'type': 'function',
    'function': {
      'name': 'browser.find',
    },
  }

  available_tools = {
    'browser.search': browser_search,
    'browser.open': browser_open,
    'browser.find': browser_find,
  }

  query = "what is ollama's new engine"
  print('Prompt:', query, '\n')

  messages: List[Dict[str, Any]] = [{'role': 'user', 'content': query}]

  while True:
    resp = client.chat(
      model='gpt-oss:120b-cloud',
      messages=messages,
      tools=[browser_search_schema, browser_open_schema, browser_find_schema],
      think=True,
    )

    if resp.message.thinking:
      print('Thinking:\n========\n')
      print(resp.message.thinking + '\n')

    if resp.message.content:
      print('Response:\n========\n')
      print(resp.message.content + '\n')

    messages.append(resp.message)

    if not resp.message.tool_calls:
      break

    for tc in resp.message.tool_calls:
      tool_name = tc.function.name
      args = tc.function.arguments or {}
      print(f'Tool name: {tool_name}, args: {args}')
      fn = available_tools.get(tool_name)
      if not fn:
        messages.append({'role': 'tool', 'content': f'Tool {tool_name} not found', 'tool_name': tool_name})
        continue

      try:
        result_text = fn(**args)
        print('Result: ', result_text[:200] + '...')
      except Exception as e:
        result_text = f'Error from {tool_name}: {e}'

      messages.append({'role': 'tool', 'content': result_text, 'tool_name': tool_name})


if __name__ == '__main__':
  main()
```
examples/web-search-mcp.py (new file, 116 lines)

@@ -0,0 +1,116 @@
```python
# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "mcp",
#   "rich",
#   "ollama",
# ]
# ///
"""
MCP stdio server exposing Ollama web_search and web_fetch as tools.

Environment:
  - OLLAMA_API_KEY (required): used as the Authorization header.
"""

from __future__ import annotations

import asyncio
from typing import Any, Dict

from ollama import Client

try:
  # Preferred high-level API (if available)
  from mcp.server.fastmcp import FastMCP  # type: ignore

  _FASTMCP_AVAILABLE = True
except Exception:
  _FASTMCP_AVAILABLE = False

if not _FASTMCP_AVAILABLE:
  # Fallback to the low-level stdio server API
  from mcp.server import Server  # type: ignore
  from mcp.server.stdio import stdio_server  # type: ignore


client = Client()


def _web_search_impl(query: str, max_results: int = 3) -> Dict[str, Any]:
  res = client.web_search(query=query, max_results=max_results)
  return res.model_dump()


def _web_fetch_impl(url: str) -> Dict[str, Any]:
  res = client.web_fetch(url=url)
  return res.model_dump()


if _FASTMCP_AVAILABLE:
  app = FastMCP('ollama-search-fetch')

  @app.tool()
  def web_search(query: str, max_results: int = 3) -> Dict[str, Any]:
    """
    Perform a web search using Ollama's hosted search API.

    Args:
      query: The search query to run.
      max_results: Maximum results to return (default: 3).

    Returns:
      JSON-serializable dict matching ollama.WebSearchResponse.model_dump()
    """

    return _web_search_impl(query=query, max_results=max_results)

  @app.tool()
  def web_fetch(url: str) -> Dict[str, Any]:
    """
    Fetch the content of a web page for the provided URL.

    Args:
      url: The absolute URL to fetch.

    Returns:
      JSON-serializable dict matching ollama.WebFetchResponse.model_dump()
    """

    return _web_fetch_impl(url=url)

  if __name__ == '__main__':
    app.run()

else:
  server = Server('ollama-search-fetch')  # type: ignore[name-defined]

  @server.tool()  # type: ignore[attr-defined]
  async def web_search(query: str, max_results: int = 3) -> Dict[str, Any]:
    """
    Perform a web search using Ollama's hosted search API.

    Args:
      query: The search query to run.
      max_results: Maximum results to return (default: 3).
    """

    return await asyncio.to_thread(_web_search_impl, query, max_results)

  @server.tool()  # type: ignore[attr-defined]
  async def web_fetch(url: str) -> Dict[str, Any]:
    """
    Fetch the content of a web page for the provided URL.

    Args:
      url: The absolute URL to fetch.
    """

    return await asyncio.to_thread(_web_fetch_impl, url)

  async def _main() -> None:
    async with stdio_server() as (read, write):  # type: ignore[name-defined]
      await server.run(read, write)  # type: ignore[attr-defined]

  if __name__ == '__main__':
    asyncio.run(_main())
```
examples/web-search.py (new file, 85 lines)

@@ -0,0 +1,85 @@
```python
# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "rich",
#   "ollama",
# ]
# ///
from typing import Union

from rich import print

from ollama import WebFetchResponse, WebSearchResponse, chat, web_fetch, web_search


def format_tool_results(
  results: Union[WebSearchResponse, WebFetchResponse],
  user_search: str,
):
  output = []
  if isinstance(results, WebSearchResponse):
    output.append(f'Search results for "{user_search}":')
    for result in results.results:
      output.append(f'{result.title}' if result.title else f'{result.content}')
      output.append(f'  URL: {result.url}')
      output.append(f'  Content: {result.content}')
      output.append('')
    return '\n'.join(output).rstrip()

  elif isinstance(results, WebFetchResponse):
    output.append(f'Fetch results for "{user_search}":')
    output.extend(
      [
        f'Title: {results.title}',
        f'URL: {user_search}' if user_search else '',
        f'Content: {results.content}',
      ]
    )
    if results.links:
      output.append(f'Links: {", ".join(results.links)}')
    output.append('')
    return '\n'.join(output).rstrip()


# client = Client(headers={'Authorization': f"Bearer {os.getenv('OLLAMA_API_KEY')}"} if api_key else None)
available_tools = {'web_search': web_search, 'web_fetch': web_fetch}

query = "what is ollama's new engine"
print('Query: ', query)

messages = [{'role': 'user', 'content': query}]
while True:
  response = chat(model='qwen3', messages=messages, tools=[web_search, web_fetch], think=True)
  if response.message.thinking:
    print('Thinking: ')
    print(response.message.thinking + '\n\n')
  if response.message.content:
    print('Content: ')
    print(response.message.content + '\n')

  messages.append(response.message)

  if response.message.tool_calls:
    for tool_call in response.message.tool_calls:
      function_to_call = available_tools.get(tool_call.function.name)
      if function_to_call:
        args = tool_call.function.arguments
        result: Union[WebSearchResponse, WebFetchResponse] = function_to_call(**args)
        print('Result from tool call name:', tool_call.function.name, 'with arguments:')
        print(args)
        print()

        user_search = args.get('query', '') or args.get('url', '')
        formatted_tool_results = format_tool_results(result, user_search=user_search)

        print(formatted_tool_results[:300])
        print()

        # caps the result at ~2000 tokens
        messages.append({'role': 'tool', 'content': formatted_tool_results[: 2000 * 4], 'tool_name': tool_call.function.name})
      else:
        print(f'Tool {tool_call.function.name} not found')
        messages.append({'role': 'tool', 'content': f'Tool {tool_call.function.name} not found', 'tool_name': tool_call.function.name})
  else:
    # no more tool calls, we can stop the loop
    break
```
examples/web_search_gpt_oss_helper.py (new file, 514 lines)

@@ -0,0 +1,514 @@
```python
from __future__ import annotations

import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol, Tuple
from urllib.parse import urlparse

from ollama import Client


@dataclass
class Page:
  url: str
  title: str
  text: str
  lines: List[str]
  links: Dict[int, str]
  fetched_at: datetime


@dataclass
class BrowserStateData:
  page_stack: List[str] = field(default_factory=list)
  view_tokens: int = 1024
  url_to_page: Dict[str, Page] = field(default_factory=dict)


@dataclass
class WebSearchResult:
  title: str
  url: str
  content: Dict[str, str]


class SearchClient(Protocol):
  def search(self, queries: List[str], max_results: Optional[int] = None): ...


class CrawlClient(Protocol):
  def crawl(self, urls: List[str]): ...


# ---- Constants ---------------------------------------------------------------

DEFAULT_VIEW_TOKENS = 1024
CAPPED_TOOL_CONTENT_LEN = 8000

# ---- Helpers ----------------------------------------------------------------


def cap_tool_content(text: str) -> str:
  if not text:
    return text
  if len(text) <= CAPPED_TOOL_CONTENT_LEN:
    return text
  if CAPPED_TOOL_CONTENT_LEN <= 1:
    return text[:CAPPED_TOOL_CONTENT_LEN]
  return text[: CAPPED_TOOL_CONTENT_LEN - 1] + '…'


def _safe_domain(u: str) -> str:
  try:
    parsed = urlparse(u)
    host = parsed.netloc or u
    return host.replace('www.', '') if host else u
  except Exception:
    return u


# ---- BrowserState ------------------------------------------------------------


class BrowserState:
  def __init__(self, initial_state: Optional[BrowserStateData] = None):
    self._data = initial_state or BrowserStateData(view_tokens=DEFAULT_VIEW_TOKENS)

  def get_data(self) -> BrowserStateData:
    return self._data

  def set_data(self, data: BrowserStateData) -> None:
    self._data = data


# ---- Browser ----------------------------------------------------------------


class Browser:
  def __init__(
    self,
    initial_state: Optional[BrowserStateData] = None,
    client: Optional[Client] = None,
  ):
    self.state = BrowserState(initial_state)
    self._client: Optional[Client] = client

  def set_client(self, client: Client) -> None:
    self._client = client

  def get_state(self) -> BrowserStateData:
    return self.state.get_data()

  # ---- internal utils ----

  def _save_page(self, page: Page) -> None:
    data = self.state.get_data()
    data.url_to_page[page.url] = page
    data.page_stack.append(page.url)
    self.state.set_data(data)

  def _page_from_stack(self, url: str) -> Page:
    data = self.state.get_data()
    page = data.url_to_page.get(url)
    if not page:
      raise ValueError(f'Page not found for url {url}')
    return page

  def _join_lines_with_numbers(self, lines: List[str]) -> str:
    result = []
    for i, line in enumerate(lines):
      result.append(f'L{i}: {line}')
    return '\n'.join(result)

  def _wrap_lines(self, text: str, width: int = 80) -> List[str]:
    if width <= 0:
      width = 80
    src_lines = text.split('\n')
    wrapped: List[str] = []
    for line in src_lines:
      if line == '':
        wrapped.append('')
      elif len(line) <= width:
        wrapped.append(line)
      else:
        words = re.split(r'\s+', line)
        if not words:
          wrapped.append(line)
          continue
        curr = ''
        for w in words:
          test = (curr + ' ' + w) if curr else w
          if len(test) > width and curr:
            wrapped.append(curr)
            curr = w
          else:
            curr = test
        if curr:
          wrapped.append(curr)
    return wrapped

  def _process_markdown_links(self, text: str) -> Tuple[str, Dict[int, str]]:
    links: Dict[int, str] = {}
    link_id = 0

    multiline_pattern = re.compile(r'\[([^\]]+)\]\s*\n\s*\(([^)]+)\)')
    text = multiline_pattern.sub(lambda m: f'[{m.group(1)}]({m.group(2)})', text)
    text = re.sub(r'\s+', ' ', text)

    link_pattern = re.compile(r'\[([^\]]+)\]\(([^)]+)\)')

    def _repl(m: re.Match) -> str:
      nonlocal link_id
      link_text = m.group(1).strip()
      link_url = m.group(2).strip()
      domain = _safe_domain(link_url)
      formatted = f'【{link_id}†{link_text}†{domain}】'
      links[link_id] = link_url
      link_id += 1
      return formatted

    processed = link_pattern.sub(_repl, text)
    return processed, links

  def _get_end_loc(self, loc: int, num_lines: int, total_lines: int, lines: List[str]) -> int:
    if num_lines <= 0:
      txt = self._join_lines_with_numbers(lines[loc:])
      data = self.state.get_data()
      chars_per_token = 4
      max_chars = min(data.view_tokens * chars_per_token, len(txt))
      num_lines = txt[:max_chars].count('\n') + 1
    return min(loc + num_lines, total_lines)

  def _display_page(self, page: Page, cursor: int, loc: int, num_lines: int) -> str:
    total_lines = len(page.lines) or 0
    if total_lines == 0:
      page.lines = ['']
      total_lines = 1

    if loc != loc or loc < 0:
      loc = 0
    elif loc >= total_lines:
      loc = max(0, total_lines - 1)

    end_loc = self._get_end_loc(loc, num_lines, total_lines, page.lines)

    header = f'[{cursor}] {page.title}'
    header += f'({page.url})\n' if page.url else '\n'
    header += f'**viewing lines [{loc} - {end_loc - 1}] of {total_lines - 1}**\n\n'

    body_lines = []
    for i in range(loc, end_loc):
      body_lines.append(f'L{i}: {page.lines[i]}')

    return header + '\n'.join(body_lines)

  # ---- page builders ----

  def _build_search_results_page_collection(self, query: str, results: Dict[str, Any]) -> Page:
    page = Page(
      url=f'search_results_{query}',
      title=query,
      text='',
      lines=[],
      links={},
      fetched_at=datetime.utcnow(),
    )

    tb = []
    tb.append('')
    tb.append('# Search Results')
    tb.append('')

    link_idx = 0
    for query_results in results.get('results', {}).values():
      for result in query_results:
        domain = _safe_domain(result.get('url', ''))
        link_fmt = f'* 【{link_idx}†{result.get("title", "")}†{domain}】'
        tb.append(link_fmt)

        raw_snip = result.get('content') or ''
        capped = (raw_snip[:400] + '…') if len(raw_snip) > 400 else raw_snip
        cleaned = re.sub(r'\d{40,}', lambda m: m.group(0)[:40] + '…', capped)
        cleaned = re.sub(r'\s{3,}', ' ', cleaned)
        tb.append(cleaned)
        page.links[link_idx] = result.get('url', '')
        link_idx += 1

    page.text = '\n'.join(tb)
    page.lines = self._wrap_lines(page.text, 80)
    return page

  def _build_search_result_page(self, result: WebSearchResult, link_idx: int) -> Page:
    page = Page(
      url=result.url,
      title=result.title,
      text='',
      lines=[],
      links={},
      fetched_at=datetime.utcnow(),
    )

    link_fmt = f'【{link_idx}†{result.title}】\n'
    preview = link_fmt + f'URL: {result.url}\n'
    full_text = result.content.get('fullText', '') if result.content else ''
    preview += full_text[:300] + '\n\n'

    if not full_text:
      page.links[link_idx] = result.url

    if full_text:
      raw = f'URL: {result.url}\n{full_text}'
      processed, links = self._process_markdown_links(raw)
      page.text = processed
      page.links = links
    else:
      page.text = preview

    page.lines = self._wrap_lines(page.text, 80)
    return page

  def _build_page_from_fetch(self, requested_url: str, fetch_response: Dict[str, Any]) -> Page:
    page = Page(
      url=requested_url,
      title=requested_url,
      text='',
      lines=[],
      links={},
      fetched_at=datetime.utcnow(),
    )

    for url, url_results in fetch_response.get('results', {}).items():
      if url_results:
        r0 = url_results[0]
        if r0.get('content'):
          page.text = r0['content']
        if r0.get('title'):
          page.title = r0['title']
        page.url = url
        break

    if not page.text:
      page.text = 'No content could be extracted from this page.'
    else:
      page.text = f'URL: {page.url}\n{page.text}'

    processed, links = self._process_markdown_links(page.text)
    page.text = processed
    page.links = links
    page.lines = self._wrap_lines(page.text, 80)
    return page

  def _build_find_results_page(self, pattern: str, page: Page) -> Page:
    find_page = Page(
      url=f'find_results_{pattern}',
      title=f'Find results for text: `{pattern}` in `{page.title}`',
      text='',
      lines=[],
      links={},
      fetched_at=datetime.utcnow(),
    )

    max_results = 50
    num_show_lines = 4
    pattern_lower = pattern.lower()

    result_chunks: List[str] = []
    line_idx = 0
    while line_idx < len(page.lines):
      line = page.lines[line_idx]
      if pattern_lower not in line.lower():
        line_idx += 1
        continue

      end_line = min(line_idx + num_show_lines, len(page.lines))
      snippet = '\n'.join(page.lines[line_idx:end_line])
      link_fmt = f'【{len(result_chunks)}†match at L{line_idx}】'
      result_chunks.append(f'{link_fmt}\n{snippet}')

      if len(result_chunks) >= max_results:
        break
      line_idx += num_show_lines

    if not result_chunks:
      find_page.text = f'No `find` results for pattern: `{pattern}`'
    else:
      find_page.text = '\n\n'.join(result_chunks)

    find_page.lines = self._wrap_lines(find_page.text, 80)
    return find_page

  # ---- public API: search / open / find ------------------------------------

  def search(self, *, query: str, topn: int = 5) -> Dict[str, Any]:
    if not self._client:
      raise RuntimeError('Client not provided')

    resp = self._client.web_search(query, max_results=topn)

    normalized: Dict[str, Any] = {'results': {}}
    rows: List[Dict[str, str]] = []
    for item in resp.results:
      content = item.content or ''
      rows.append(
        {
          'title': item.title,
          'url': item.url,
          'content': content,
        }
      )
    normalized['results'][query] = rows

    search_page = self._build_search_results_page_collection(query, normalized)
    self._save_page(search_page)
    cursor = len(self.get_state().page_stack) - 1

    for query_results in normalized.get('results', {}).values():
      for i, r in enumerate(query_results):
        ws = WebSearchResult(
          title=r.get('title', ''),
          url=r.get('url', ''),
          content={'fullText': r.get('content', '') or ''},
        )
        result_page = self._build_search_result_page(ws, i + 1)
        data = self.get_state()
        data.url_to_page[result_page.url] = result_page
        self.state.set_data(data)

    page_text = self._display_page(search_page, cursor, loc=0, num_lines=-1)
    return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}

  def open(
    self,
    *,
    id: Optional[str | int] = None,
    cursor: int = -1,
    loc: int = 0,
    num_lines: int = -1,
  ) -> Dict[str, Any]:
    if not self._client:
      raise RuntimeError('Client not provided')

    state = self.get_state()

    if isinstance(id, str):
      url = id
      if url in state.url_to_page:
        self._save_page(state.url_to_page[url])
        cursor = len(self.get_state().page_stack) - 1
        page_text = self._display_page(state.url_to_page[url], cursor, loc, num_lines)
        return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}

      fetch_response = self._client.web_fetch(url)
      normalized: Dict[str, Any] = {
        'results': {
          url: [
            {
              'title': fetch_response.title or url,
              'url': url,
              'content': fetch_response.content or '',
            }
          ]
        }
      }
      new_page = self._build_page_from_fetch(url, normalized)
      self._save_page(new_page)
      cursor = len(self.get_state().page_stack) - 1
      page_text = self._display_page(new_page, cursor, loc, num_lines)
      return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}

    # Resolve current page from stack only if needed (int id or no id)
    page: Optional[Page] = None
    if cursor >= 0:
      if state.page_stack:
        if cursor >= len(state.page_stack):
          cursor = max(0, len(state.page_stack) - 1)
        page = self._page_from_stack(state.page_stack[cursor])
      else:
        page = None
    else:
      if state.page_stack:
        page = self._page_from_stack(state.page_stack[-1])

    if isinstance(id, int):
      if not page:
        raise RuntimeError('No current page to resolve link from')

      link_url = page.links.get(id)
      if not link_url:
        err = Page(
          url=f'invalid_link_{id}',
          title=f'No link with id {id} on `{page.title}`',
          text='',
          lines=[],
          links={},
          fetched_at=datetime.utcnow(),
        )
        available = sorted(page.links.keys())
        available_list = ', '.join(map(str, available)) if available else '(none)'
        err.text = '\n'.join(
          [
            f'Requested link id: {id}',
            f'Current page: {page.title}',
            f'Available link ids on this page: {available_list}',
            '',
            'Tips:',
            '- To scroll this page, call browser_open with { loc, num_lines } (no id).',
            '- To open a result from a search results page, pass the correct { cursor, id }.',
          ]
        )
        err.lines = self._wrap_lines(err.text, 80)
        self._save_page(err)
        cursor = len(self.get_state().page_stack) - 1
        page_text = self._display_page(err, cursor, 0, -1)
        return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}

      new_page = state.url_to_page.get(link_url)
      if not new_page:
        fetch_response = self._client.web_fetch(link_url)
        normalized: Dict[str, Any] = {
          'results': {
            link_url: [
              {
                'title': fetch_response.title or link_url,
                'url': link_url,
                'content': fetch_response.content or '',
              }
            ]
          }
        }
        new_page = self._build_page_from_fetch(link_url, normalized)

      self._save_page(new_page)
      cursor = len(self.get_state().page_stack) - 1
      page_text = self._display_page(new_page, cursor, loc, num_lines)
      return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}

    if not page:
      raise RuntimeError('No current page to display')

    cur = self.get_state()
    cur.page_stack.append(page.url)
    self.state.set_data(cur)
    cursor = len(cur.page_stack) - 1
    page_text = self._display_page(page, cursor, loc, num_lines)
    return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}

  def find(self, *, pattern: str, cursor: int = -1) -> Dict[str, Any]:
    state = self.get_state()
    if cursor == -1:
      if not state.page_stack:
        raise RuntimeError('No pages to search in')
      page = self._page_from_stack(state.page_stack[-1])
      cursor = len(state.page_stack) - 1
    else:
      if cursor < 0 or cursor >= len(state.page_stack):
        cursor = max(0, min(cursor, len(state.page_stack) - 1))
      page = self._page_from_stack(state.page_stack[cursor])

    find_page = self._build_find_results_page(pattern, page)
    self._save_page(find_page)
    new_cursor = len(self.get_state().page_stack) - 1

    page_text = self._display_page(find_page, new_cursor, 0, -1)
    return {'state': self.get_state(), 'pageText': cap_tool_content(page_text)}
```
ollama/__init__.py

@@ -15,6 +15,8 @@ from ollama._types import (
   ShowResponse,
   StatusResponse,
   Tool,
+  WebFetchResponse,
+  WebSearchResponse,
 )

 __all__ = [
@@ -35,6 +37,8 @@ __all__ = [
   'ShowResponse',
   'StatusResponse',
   'Tool',
+  'WebFetchResponse',
+  'WebSearchResponse',
 ]

 _client = Client()
@@ -51,3 +55,5 @@ list = _client.list
 copy = _client.copy
 show = _client.show
 ps = _client.ps
+web_search = _client.web_search
+web_fetch = _client.web_fetch
ollama/_client.py

@@ -1,3 +1,4 @@
+import contextlib
 import ipaddress
 import json
 import os
@@ -66,12 +67,16 @@ from ollama._types import (
   ShowResponse,
   StatusResponse,
   Tool,
+  WebFetchRequest,
+  WebFetchResponse,
+  WebSearchRequest,
+  WebSearchResponse,
 )

 T = TypeVar('T')


-class BaseClient:
+class BaseClient(contextlib.AbstractContextManager, contextlib.AbstractAsyncContextManager):
   def __init__(
     self,
     client,
@@ -90,23 +95,34 @@ class BaseClient:
     `kwargs` are passed to the httpx client.
     """

+    headers = {
+      k.lower(): v
+      for k, v in {
+        **(headers or {}),
+        'Content-Type': 'application/json',
+        'Accept': 'application/json',
+        'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
+      }.items()
+      if v is not None
+    }
+    api_key = os.getenv('OLLAMA_API_KEY', None)
+    if not headers.get('authorization') and api_key:
+      headers['authorization'] = f'Bearer {api_key}'
+
     self._client = client(
       base_url=_parse_host(host or os.getenv('OLLAMA_HOST')),
       follow_redirects=follow_redirects,
       timeout=timeout,
-      # Lowercase all headers to ensure override
-      headers={
-        k.lower(): v
-        for k, v in {
-          **(headers or {}),
-          'Content-Type': 'application/json',
-          'Accept': 'application/json',
-          'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
-        }.items()
-      },
+      headers=headers,
       **kwargs,
     )

+  def __exit__(self, exc_type, exc_val, exc_tb):
+    self.close()
+
+  async def __aexit__(self, exc_type, exc_val, exc_tb):
+    await self.close()


 CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
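Taken together, the hunk above means both clients can now be used as context managers, and a Bearer token is attached automatically when `OLLAMA_API_KEY` is set. A minimal sketch (the model name is an assumption, not part of the change):

```python
# Minimal sketch of the context-manager support added above; assumes a
# reachable Ollama server and an installed model (name is illustrative).
# If OLLAMA_API_KEY is set, it is attached as a Bearer token automatically.
import asyncio

from ollama import AsyncClient, Client

with Client() as client:  # __exit__ calls client.close()
  reply = client.chat('gemma3', messages=[{'role': 'user', 'content': 'hi'}])
  print(reply.message.content)


async def main() -> None:
  async with AsyncClient() as client:  # __aexit__ awaits client.close()
    reply = await client.chat('gemma3', messages=[{'role': 'user', 'content': 'hi'}])
    print(reply.message.content)


asyncio.run(main())
```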
@@ -115,6 +131,9 @@ class Client(BaseClient):
   def __init__(self, host: Optional[str] = None, **kwargs) -> None:
     super().__init__(httpx.Client, host, **kwargs)

+  def close(self):
+    self._client.close()
+
   def _request_raw(self, *args, **kwargs):
     try:
       r = self._client.request(*args, **kwargs)
@@ -191,6 +210,8 @@ class Client(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: Literal[False] = False,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -210,6 +231,8 @@ class Client(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: Literal[True] = True,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -228,6 +251,8 @@ class Client(BaseClient):
     context: Optional[Sequence[int]] = None,
     stream: bool = False,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: Optional[bool] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -257,6 +282,8 @@ class Client(BaseClient):
       context=context,
       stream=stream,
       think=think,
+      logprobs=logprobs,
+      top_logprobs=top_logprobs,
       raw=raw,
       format=format,
       images=list(_copy_images(images)) if images else None,
@@ -274,7 +301,9 @@ class Client(BaseClient):
     *,
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[False] = False,
-    think: Optional[bool] = None,
+    think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -288,7 +317,9 @@ class Client(BaseClient):
     *,
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[True] = True,
-    think: Optional[bool] = None,
+    think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -301,7 +332,9 @@ class Client(BaseClient):
     *,
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: bool = False,
-    think: Optional[bool] = None,
+    think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -349,6 +382,8 @@ class Client(BaseClient):
       tools=list(_copy_tools(tools)),
       stream=stream,
       think=think,
+      logprobs=logprobs,
+      top_logprobs=top_logprobs,
       format=format,
       options=options,
       keep_alive=keep_alive,
@@ -363,6 +398,7 @@ class Client(BaseClient):
     truncate: Optional[bool] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    dimensions: Optional[int] = None,
   ) -> EmbedResponse:
     return self._request(
       EmbedResponse,
@@ -374,6 +410,7 @@ class Client(BaseClient):
       truncate=truncate,
       options=options,
       keep_alive=keep_alive,
+      dimensions=dimensions,
     ).model_dump(exclude_none=True),
   )
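The new `dimensions` parameter is forwarded into the embed request; a hedged usage sketch (assumes the chosen model supports configurable embedding sizes; the model name is illustrative):

```python
# Hedged sketch of the new embed(dimensions=...) parameter; assumes the
# model supports configurable embedding sizes (model name illustrative).
import ollama

response = ollama.embed(
  model='embeddinggemma',
  input='Hello, world!',
  dimensions=256,
)
print(len(response.embeddings[0]))  # expected: 256
```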
@@ -622,11 +659,62 @@ class Client(BaseClient):
       '/api/ps',
     )

+  def web_search(self, query: str, max_results: int = 3) -> WebSearchResponse:
+    """
+    Performs a web search.
+
+    Args:
+      query: The query to search for
+      max_results: The maximum number of results to return (default: 3)
+
+    Returns:
+      WebSearchResponse with the search results
+
+    Raises:
+      ValueError: If no Bearer token is set (e.g. the OLLAMA_API_KEY environment variable is not set)
+    """
+    if not self._client.headers.get('authorization', '').startswith('Bearer '):
+      raise ValueError('Authorization header with Bearer token is required for web search')
+
+    return self._request(
+      WebSearchResponse,
+      'POST',
+      'https://ollama.com/api/web_search',
+      json=WebSearchRequest(
+        query=query,
+        max_results=max_results,
+      ).model_dump(exclude_none=True),
+    )
+
+  def web_fetch(self, url: str) -> WebFetchResponse:
+    """
+    Fetches the content of a web page for the provided URL.
+
+    Args:
+      url: The URL to fetch
+
+    Returns:
+      WebFetchResponse with the fetched result
+    """
+    if not self._client.headers.get('authorization', '').startswith('Bearer '):
+      raise ValueError('Authorization header with Bearer token is required for web fetch')
+
+    return self._request(
+      WebFetchResponse,
+      'POST',
+      'https://ollama.com/api/web_fetch',
+      json=WebFetchRequest(
+        url=url,
+      ).model_dump(exclude_none=True),
+    )
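A minimal usage sketch for the two methods above (hedged: `OLLAMA_API_KEY` must be set so the Bearer-token check passes):

```python
# Minimal sketch; requires OLLAMA_API_KEY in the environment so the
# Bearer-token check above passes.
import ollama

results = ollama.web_search('What is Ollama?', max_results=3)
for result in results.results:
  print(result.title, result.url)

page = ollama.web_fetch('https://ollama.com')
print(page.title)
print((page.content or '')[:200])
```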
 class AsyncClient(BaseClient):
   def __init__(self, host: Optional[str] = None, **kwargs) -> None:
     super().__init__(httpx.AsyncClient, host, **kwargs)

+  async def close(self):
+    await self._client.aclose()
+
   async def _request_raw(self, *args, **kwargs):
     try:
       r = await self._client.request(*args, **kwargs)
@@ -691,6 +779,46 @@ class AsyncClient(BaseClient):

     return cls(**(await self._request_raw(*args, **kwargs)).json())

+  async def web_search(self, query: str, max_results: int = 3) -> WebSearchResponse:
+    """
+    Performs a web search.
+
+    Args:
+      query: The query to search for
+      max_results: The maximum number of results to return (default: 3)
+
+    Returns:
+      WebSearchResponse with the search results
+    """
+    return await self._request(
+      WebSearchResponse,
+      'POST',
+      'https://ollama.com/api/web_search',
+      json=WebSearchRequest(
+        query=query,
+        max_results=max_results,
+      ).model_dump(exclude_none=True),
+    )
+
+  async def web_fetch(self, url: str) -> WebFetchResponse:
+    """
+    Fetches the content of a web page for the provided URL.
+
+    Args:
+      url: The URL to fetch
+
+    Returns:
+      WebFetchResponse with the fetched result
+    """
+    return await self._request(
+      WebFetchResponse,
+      'POST',
+      'https://ollama.com/api/web_fetch',
+      json=WebFetchRequest(
+        url=url,
+      ).model_dump(exclude_none=True),
+    )
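And the async counterpart (hedged sketch; same `OLLAMA_API_KEY` requirement):

```python
# Hedged async sketch; same OLLAMA_API_KEY requirement as the sync client.
import asyncio

from ollama import AsyncClient


async def main() -> None:
  async with AsyncClient() as client:
    results = await client.web_search('What is Ollama?', max_results=3)
    for result in results.results:
      print(result.title, result.url)


asyncio.run(main())
```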
@overload
|
||||
async def generate(
|
||||
self,
|
||||
@ -702,7 +830,9 @@ class AsyncClient(BaseClient):
|
||||
template: str = '',
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: Literal[False] = False,
|
||||
think: Optional[bool] = None,
|
||||
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
raw: bool = False,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||
@ -721,7 +851,9 @@ class AsyncClient(BaseClient):
|
||||
template: str = '',
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: Literal[True] = True,
|
||||
think: Optional[bool] = None,
|
||||
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
raw: bool = False,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||
@ -739,7 +871,9 @@ class AsyncClient(BaseClient):
|
||||
template: Optional[str] = None,
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: bool = False,
|
||||
think: Optional[bool] = None,
|
||||
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
raw: Optional[bool] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||
@ -768,6 +902,8 @@ class AsyncClient(BaseClient):
|
||||
context=context,
|
||||
stream=stream,
|
||||
think=think,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
raw=raw,
|
||||
format=format,
|
||||
images=list(_copy_images(images)) if images else None,
|
||||
@@ -785,7 +921,9 @@ class AsyncClient(BaseClient):
     *,
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[False] = False,
-    think: Optional[bool] = None,
+    think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -799,7 +937,9 @@ class AsyncClient(BaseClient):
     *,
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[True] = True,
-    think: Optional[bool] = None,
+    think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -812,7 +952,9 @@ class AsyncClient(BaseClient):
     *,
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: bool = False,
-    think: Optional[bool] = None,
+    think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -861,6 +1003,8 @@ class AsyncClient(BaseClient):
       tools=list(_copy_tools(tools)),
       stream=stream,
       think=think,
+      logprobs=logprobs,
+      top_logprobs=top_logprobs,
       format=format,
       options=options,
       keep_alive=keep_alive,
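The same options also flow through streaming chat. A sketch of requesting per-token alternatives while streaming; the model name is an assumption:

```python
import asyncio

from ollama import AsyncClient


async def main():
  client = AsyncClient()
  stream = await client.chat(
    'llama3.1',  # any locally pulled model
    messages=[{'role': 'user', 'content': 'Hi'}],
    stream=True,
    logprobs=True,
    top_logprobs=2,
  )
  async for part in stream:
    print(part.message.content, end='', flush=True)


asyncio.run(main())
```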
@@ -875,6 +1019,7 @@ class AsyncClient(BaseClient):
     truncate: Optional[bool] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    dimensions: Optional[int] = None,
   ) -> EmbedResponse:
     return await self._request(
       EmbedResponse,
@@ -886,6 +1031,7 @@ class AsyncClient(BaseClient):
         truncate=truncate,
         options=options,
         keep_alive=keep_alive,
+        dimensions=dimensions,
       ).model_dump(exclude_none=True),
     )
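`dimensions` lets callers truncate the returned embedding server-side. A sketch with the synchronous client, which carries the same parameter; the model name is an assumption and must be an embedding model that supports truncated output dimensions:

```python
from ollama import Client

client = Client()

# Request a 256-dimensional embedding instead of the model's full width.
response = client.embed('embeddinggemma', input='Why is the sky blue?', dimensions=256)
print(len(response.embeddings[0]))  # expected: 256
```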
@@ -207,9 +207,15 @@ class GenerateRequest(BaseGenerateRequest):
   images: Optional[Sequence[Image]] = None
   'Image data for multimodal models.'

-  think: Optional[bool] = None
+  think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
   'Enable thinking mode (for thinking models).'

+  logprobs: Optional[bool] = None
+  'Return log probabilities for generated tokens.'
+
+  top_logprobs: Optional[int] = None
+  'Number of alternative tokens and log probabilities to include per position (0-20).'
+

 class BaseGenerateResponse(SubscriptableBaseModel):
   model: Optional[str] = None
@@ -243,6 +249,19 @@ class BaseGenerateResponse(SubscriptableBaseModel):
   'Duration of evaluating inference in nanoseconds.'


+class TokenLogprob(SubscriptableBaseModel):
+  token: str
+  'Token text.'
+
+  logprob: float
+  'Log probability for the token.'
+
+
+class Logprob(TokenLogprob):
+  top_logprobs: Optional[Sequence[TokenLogprob]] = None
+  'Most likely tokens and their log probabilities.'
+
+
 class GenerateResponse(BaseGenerateResponse):
   """
   Response returned by generate requests.
@@ -257,6 +276,9 @@ class GenerateResponse(BaseGenerateResponse):
   context: Optional[Sequence[int]] = None
   'Tokenized history up to the point of the response.'

+  logprobs: Optional[Sequence[Logprob]] = None
+  'Log probabilities for generated tokens.'
+

 class Message(SubscriptableBaseModel):
   """
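`Logprob` extends `TokenLogprob` with the per-position alternatives, so each entry in a response's `logprobs` nests one level. A small sketch of that shape; the import from the private `_types` module is an assumption and the names may not be re-exported at the package root:

```python
from ollama._types import Logprob, TokenLogprob

lp = Logprob(
  token='Hello',
  logprob=-0.1,
  top_logprobs=[
    TokenLogprob(token='Hello', logprob=-0.1),
    TokenLogprob(token='Hi', logprob=-1.0),
  ],
)
# SubscriptableBaseModel allows both attribute and dict-style access.
print(lp.token, lp['top_logprobs'][1]['token'])
```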
@@ -357,9 +379,15 @@ class ChatRequest(BaseGenerateRequest):
   tools: Optional[Sequence[Tool]] = None
   'Tools to use for the chat.'

-  think: Optional[bool] = None
+  think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
   'Enable thinking mode (for thinking models).'

+  logprobs: Optional[bool] = None
+  'Return log probabilities for generated tokens.'
+
+  top_logprobs: Optional[int] = None
+  'Number of alternative tokens and log probabilities to include per position (0-20).'
+

 class ChatResponse(BaseGenerateResponse):
   """
@@ -369,6 +397,9 @@ class ChatResponse(BaseGenerateResponse):
   message: Message
   'Response message.'

+  logprobs: Optional[Sequence[Logprob]] = None
+  'Log probabilities for generated tokens if requested.'
+

 class EmbedRequest(BaseRequest):
   input: Union[str, Sequence[str]]
@@ -382,6 +413,9 @@ class EmbedRequest(BaseRequest):

   keep_alive: Optional[Union[float, str]] = None

+  dimensions: Optional[int] = None
+  'Dimensions truncates the output embedding to the specified dimension.'
+

 class EmbedResponse(BaseGenerateResponse):
   """
@@ -538,6 +572,31 @@ class ProcessResponse(SubscriptableBaseModel):
   models: Sequence[Model]


+class WebSearchRequest(SubscriptableBaseModel):
+  query: str
+  max_results: Optional[int] = None
+
+
+class WebSearchResult(SubscriptableBaseModel):
+  content: Optional[str] = None
+  title: Optional[str] = None
+  url: Optional[str] = None
+
+
+class WebFetchRequest(SubscriptableBaseModel):
+  url: str
+
+
+class WebSearchResponse(SubscriptableBaseModel):
+  results: Sequence[WebSearchResult]
+
+
+class WebFetchResponse(SubscriptableBaseModel):
+  title: Optional[str] = None
+  content: Optional[str] = None
+  links: Optional[Sequence[str]] = None
+
+
 class RequestError(Exception):
   """
   Common class for request errors.
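These models mirror the JSON bodies of the web endpoints. A sketch of validating a raw payload into the new types; the import path from the private `_types` module and the payload contents are assumptions:

```python
from ollama._types import WebSearchResponse

payload = {
  'results': [
    {'title': 'Ollama', 'url': 'https://ollama.com', 'content': 'Run models locally.'},
  ]
}
resp = WebSearchResponse.model_validate(payload)
print(resp.results[0].title)      # attribute access
print(resp['results'][0]['url'])  # dict-style access also works
```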
@@ -37,7 +37,7 @@ dependencies = [ 'ruff>=0.9.1' ]
 config-path = 'none'

 [tool.ruff]
-line-length = 999
+line-length = 320
 indent-width = 2

 [tool.ruff.format]
@@ -61,6 +61,44 @@ def test_client_chat(httpserver: HTTPServer):
   assert response['message']['content'] == "I don't know."


+def test_client_chat_with_logprobs(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/chat',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'messages': [{'role': 'user', 'content': 'Hi'}],
+      'tools': [],
+      'stream': False,
+      'logprobs': True,
+      'top_logprobs': 3,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': 'Hello',
+      },
+      'logprobs': [
+        {
+          'token': 'Hello',
+          'logprob': -0.1,
+          'top_logprobs': [
+            {'token': 'Hello', 'logprob': -0.1},
+            {'token': 'Hi', 'logprob': -1.0},
+          ],
+        }
+      ],
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Hi'}], logprobs=True, top_logprobs=3)
+  assert response['logprobs'][0]['token'] == 'Hello'
+  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
+
+
 def test_client_chat_stream(httpserver: HTTPServer):
   def stream_handler(_: Request):
     def generate():
@@ -294,6 +332,40 @@ def test_client_generate(httpserver: HTTPServer):
   assert response['response'] == 'Because it is.'


+def test_client_generate_with_logprobs(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'prompt': 'Why',
+      'stream': False,
+      'logprobs': True,
+      'top_logprobs': 2,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': 'Hello',
+      'logprobs': [
+        {
+          'token': 'Hello',
+          'logprob': -0.2,
+          'top_logprobs': [
+            {'token': 'Hello', 'logprob': -0.2},
+            {'token': 'Hi', 'logprob': -1.5},
+          ],
+        }
+      ],
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy', 'Why', logprobs=True, top_logprobs=2)
+  assert response['logprobs'][0]['token'] == 'Hello'
+  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
+
+
 def test_client_generate_with_image_type(httpserver: HTTPServer):
   httpserver.expect_ordered_request(
     '/api/generate',
@@ -1195,3 +1267,113 @@ async def test_arbitrary_roles_accepted_in_message_request_async(monkeypatch: py
   client = AsyncClient()

   await client.chat(model='llama3.1', messages=[{'role': 'somerandomrole', 'content': "I'm ok with you adding any role message now!"}, {'role': 'user', 'content': 'Hello world!'}])
+
+
+def test_client_web_search_requires_bearer_auth_header(monkeypatch: pytest.MonkeyPatch):
+  monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
+
+  client = Client()
+
+  with pytest.raises(ValueError, match='Authorization header with Bearer token is required for web search'):
+    client.web_search('test query')
+
+
+def test_client_web_fetch_requires_bearer_auth_header(monkeypatch: pytest.MonkeyPatch):
+  monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
+
+  client = Client()
+
+  with pytest.raises(ValueError, match='Authorization header with Bearer token is required for web fetch'):
+    client.web_fetch('https://example.com')
+
+
+def _mock_request_web_search(self, cls, method, url, json=None, **kwargs):
+  assert method == 'POST'
+  assert url == 'https://ollama.com/api/web_search'
+  assert json is not None and 'query' in json and 'max_results' in json
+  return httpxResponse(status_code=200, content='{"results": {}, "success": true}')
+
+
+def _mock_request_web_fetch(self, cls, method, url, json=None, **kwargs):
+  assert method == 'POST'
+  assert url == 'https://ollama.com/api/web_fetch'
+  assert json is not None and 'url' in json
+  return httpxResponse(status_code=200, content='{"results": {}, "success": true}')
+
+
+def test_client_web_search_with_env_api_key(monkeypatch: pytest.MonkeyPatch):
+  monkeypatch.setenv('OLLAMA_API_KEY', 'test-key')
+  monkeypatch.setattr(Client, '_request', _mock_request_web_search)
+
+  client = Client()
+  client.web_search('what is ollama?', max_results=2)
+
+
+def test_client_web_fetch_with_env_api_key(monkeypatch: pytest.MonkeyPatch):
+  monkeypatch.setenv('OLLAMA_API_KEY', 'test-key')
+  monkeypatch.setattr(Client, '_request', _mock_request_web_fetch)
+
+  client = Client()
+  client.web_fetch('https://example.com')
+
+
+def test_client_web_search_with_explicit_bearer_header(monkeypatch: pytest.MonkeyPatch):
+  monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
+  monkeypatch.setattr(Client, '_request', _mock_request_web_search)
+
+  client = Client(headers={'Authorization': 'Bearer custom-token'})
+  client.web_search('what is ollama?', max_results=1)
+
+
+def test_client_web_fetch_with_explicit_bearer_header(monkeypatch: pytest.MonkeyPatch):
+  monkeypatch.delenv('OLLAMA_API_KEY', raising=False)
+  monkeypatch.setattr(Client, '_request', _mock_request_web_fetch)
+
+  client = Client(headers={'Authorization': 'Bearer custom-token'})
+  client.web_fetch('https://example.com')
+
+
+def test_client_bearer_header_from_env(monkeypatch: pytest.MonkeyPatch):
+  monkeypatch.setenv('OLLAMA_API_KEY', 'env-token')
+
+  client = Client()
+  assert client._client.headers['authorization'] == 'Bearer env-token'
+
+
+def test_client_explicit_bearer_header_overrides_env(monkeypatch: pytest.MonkeyPatch):
+  monkeypatch.setenv('OLLAMA_API_KEY', 'env-token')
+  monkeypatch.setattr(Client, '_request', _mock_request_web_search)
+
+  client = Client(headers={'Authorization': 'Bearer explicit-token'})
+  assert client._client.headers['authorization'] == 'Bearer explicit-token'
+  client.web_search('override check')
+
+
+def test_client_close():
+  client = Client()
+  client.close()
+  assert client._client.is_closed
+
+
+@pytest.mark.anyio
+async def test_async_client_close():
+  client = AsyncClient()
+  await client.close()
+  assert client._client.is_closed
+
+
+def test_client_context_manager():
+  with Client() as client:
+    assert isinstance(client, Client)
+    assert not client._client.is_closed
+
+  assert client._client.is_closed
+
+
+@pytest.mark.anyio
+async def test_async_client_context_manager():
+  async with AsyncClient() as client:
+    assert isinstance(client, AsyncClient)
+    assert not client._client.is_closed
+
+  assert client._client.is_closed
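The close and context-manager tests above exercise this usage pattern; a short sketch, assuming a local Ollama server with the named model pulled:

```python
from ollama import Client

with Client() as client:
  response = client.chat('llama3.1', messages=[{'role': 'user', 'content': 'Hi'}])
  print(response.message.content)
# The underlying httpx client is closed on exit from the with-block.
```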