fix types, client, example

This commit is contained in:
ParthSareen 2025-09-23 13:22:45 -07:00
parent 9d051dccae
commit 79eda0c2f5
3 changed files with 31 additions and 38 deletions

View File

@ -5,54 +5,51 @@
# "ollama",
# ]
# ///
import os
from typing import Union
from rich import print
from ollama import Client, WebFetchResponse, WebSearchResponse
from ollama import WebFetchResponse, WebSearchResponse, chat, web_fetch, web_search
def format_tool_results(
results: Union[WebSearchResponse, WebFetchResponse],
user_search: str,
):
output = []
if isinstance(results, WebSearchResponse):
output = []
output.append(f'Search results for "{user_search}":')
for i, result in enumerate(results.results, 1):
title = getattr(result, 'title', None)
url_value = getattr(result, 'url', None)
output.append(f'{i}. {title}' if title else f'{i}. {getattr(result, "content", "")}')
if url_value:
output.append(f' URL: {url_value}')
output.append(f' Content: {getattr(result, "content", "")}')
for result in results.results:
output.append(f'{result.title}' if result.title else f'{result.content}')
output.append(f' URL: {result.url}')
output.append(f' Content: {result.content}')
output.append('')
return '\n'.join(output).rstrip()
elif isinstance(results, WebFetchResponse):
output = []
output.append(f'Fetch results for "{user_search}":')
output.extend([
f'Title: {results.title}',
f'URL: {user_search}' if user_search else '',
f'Content: {results.content}',
])
output.extend(
[
f'Title: {results.title}',
f'URL: {user_search}' if user_search else '',
f'Content: {results.content}',
]
)
if results.links:
output.append(f'Links: {", ".join(results.links)}')
output.append('')
return '\n'.join(output).rstrip()
api_key = os.getenv('OLLAMA_API_KEY')
client = Client(headers={'Authorization': f"Bearer {s.getenv('OLLAMA_API_KEY')}"} if api_key else None)
available_tools = {'web_search': client.web_search, 'web_fetch': client.web_fetch}
query = "ollama's new engine"
# client = Client(headers={'Authorization': f"Bearer {os.getenv('OLLAMA_API_KEY')}"} if api_key else None)
available_tools = {'web_search': web_search, 'web_fetch': web_fetch}
query = "what is ollama's new engine"
print('Query: ', query)
messages = [{'role': 'user', 'content': query}]
while True:
response = client.chat(model='qwen3', messages=messages, tools=[client.web_search, client.web_fetch], think=True)
response = chat(model='qwen3', messages=messages, tools=[web_search, web_fetch], think=True)
if response.message.thinking:
print('Thinking: ')
print(response.message.thinking + '\n\n')
@ -73,15 +70,9 @@ while True:
print()
user_search = args.get('query', '') or args.get('url', '')
if tool_call.function.name == 'web_search':
formatted_tool_results = format_tool_results(result, user_search=user_search)
elif tool_call.function.name == 'web_fetch':
formatted_tool_results = format_tool_results(result, user_search=user_search)
else:
formatted_tool_results = format_tool_results(result)
formatted_tool_results = format_tool_results(result, user_search=user_search)
print('Result:')
print(formatted_tool_results[:200])
print(formatted_tool_results[:300])
print()
# caps the result at ~2000 tokens

View File

@ -639,7 +639,7 @@ class Client(BaseClient):
Args:
query: The query to search for
max_results: The maximum number of results to return.
max_results: The maximum number of results to return (default: 3)
Returns:
WebSearchResponse with the search results
@ -750,13 +750,13 @@ class AsyncClient(BaseClient):
return cls(**(await self._request_raw(*args, **kwargs)).json())
async def websearch(self, query: str, max_results: int = 3) -> WebSearchResponse:
async def web_search(self, query: str, max_results: int = 3) -> WebSearchResponse:
"""
Performs a web search
Args:
query: The query to search for
max_results: The maximum number of results to return.
max_results: The maximum number of results to return (default: 3)
Returns:
WebSearchResponse with the search results
@ -764,14 +764,14 @@ class AsyncClient(BaseClient):
return await self._request(
WebSearchResponse,
'POST',
'http://localhost:8080/api/web_search',
'https://ollama.com/api/web_search',
json=WebSearchRequest(
query=query,
max_results=max_results,
).model_dump(exclude_none=True),
)
async def webfetch(self, url: str) -> WebFetchResponse:
async def web_fetch(self, url: str) -> WebFetchResponse:
"""
Fetches the content of a web page for the provided URL.
@ -784,7 +784,7 @@ class AsyncClient(BaseClient):
return await self._request(
WebFetchResponse,
'POST',
'http://localhost:8080/api/web_fetch',
'https://ollama.com/api/web_fetch',
json=WebFetchRequest(
url=url,
).model_dump(exclude_none=True),

View File

@ -547,7 +547,9 @@ class WebSearchRequest(SubscriptableBaseModel):
class WebSearchResult(SubscriptableBaseModel):
content: str
content: Optional[str] = None
title: Optional[str] = None
url: Optional[str] = None
class WebFetchRequest(SubscriptableBaseModel):
@ -559,8 +561,8 @@ class WebSearchResponse(SubscriptableBaseModel):
class WebFetchResponse(SubscriptableBaseModel):
title: str
content: str
title: Optional[str] = None
content: Optional[str] = None
links: Optional[Sequence[str]] = None