Compare commits

..

19 Commits

Author SHA1 Message Date
Jeffrey Morgan 1066246ab5 fix validation of format field to allow empty strings as it did previously (#369) 2024-12-07 19:37:55 -08:00
Parth Sareen 4b10dee2b2 Structured outputs support with examples (#354) 2024-12-05 15:40:49 -08:00
dependabot[bot] e956a331e8 Merge pull request #358 from ollama/dependabot/pip/pytest-8.3.4 2024-12-03 21:15:57 +00:00
dependabot[bot] 12f7302d5f Bump pytest from 8.3.3 to 8.3.4
Bumps [pytest](https://github.com/pytest-dev/pytest) from 8.3.3 to 8.3.4.
- [Release notes](https://github.com/pytest-dev/pytest/releases)
- [Changelog](https://github.com/pytest-dev/pytest/blob/main/CHANGELOG.rst)
- [Commits](https://github.com/pytest-dev/pytest/compare/8.3.3...8.3.4)

---
updated-dependencies:
- dependency-name: pytest
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-12-02 23:33:31 +00:00
Parth Sareen 366180aa8f Improve tool example to showcase chatting (#352) 2024-11-29 20:34:19 -08:00
Parth Sareen d6528cf731 Fix image serialization for long image string (#348) 2024-11-28 14:12:55 -08:00
Julia Scheaffer b50a65b27d Add Callable type annotation for Tools (#344) 2024-11-27 09:53:26 -08:00
Jeffrey Morgan 758a1d2933 make subscription methods more consistent 2024-11-26 10:52:45 -08:00
Jeffrey Morgan d4c38978d1 Update ollama/_types.py
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2024-11-26 10:41:53 -08:00
Jeffrey Morgan d8d98e17b2 Update ollama/_types.py
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2024-11-26 10:41:50 -08:00
Jeffrey Morgan ec2c8fdd8d Update ollama/_types.py
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2024-11-26 10:41:45 -08:00
Jeffrey Morgan ea0e0dc692 Update ollama/_types.py
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2024-11-26 10:35:57 -08:00
Shahil Yadav 6c44bb2729 Fix chat-with-history.py example (#337)
Fix chat-with-history.py example

---------

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2024-11-24 10:49:46 -08:00
jmorganca 2095fc9107 make subscription methods more consistent with maps 2024-11-23 19:02:28 -08:00
Jeffrey Morgan 64e3723e6b Merge pull request #334 from ollama/mxyng/hasattr-none 2024-11-23 18:27:22 -08:00
jmorganca 1e22f2e118 add test case for explicit None 2024-11-23 18:22:38 -08:00
jmorganca 00c64332cc check defaults that aren't None too 2024-11-23 18:15:05 -08:00
Michael Yang 986fb4c7b3 fix: skip tests on examples, readme 2024-11-23 16:44:09 -08:00
Michael Yang c6ade633b8 fix: hasattr checks if attr is None 2024-11-23 16:44:09 -08:00
14 changed files with 494 additions and 47 deletions
+1 -1
View File
@@ -2,7 +2,7 @@ name: test
on:
pull_request:
paths:
paths-ignore:
- 'examples/**'
- '**/README.md'
-3
View File
@@ -37,9 +37,6 @@ See [_types.py](ollama/_types.py) for more information on the response types.
Response streaming can be enabled by setting `stream=True`.
> [!NOTE]
> Streaming Tool/Function calling is not yet supported.
```python
from ollama import chat
+6
View File
@@ -30,6 +30,12 @@ python3 examples/<example>.py
- [multimodal_generate.py](multimodal_generate.py)
### Structured Outputs - Generate structured outputs with a model
- [structured-outputs.py](structured-outputs.py)
- [async-structured-outputs.py](async-structured-outputs.py)
- [structured-outputs-image.py](structured-outputs-image.py)
### Ollama List - List all downloaded models and their properties
- [list.py](list.py)
+32
View File
@@ -0,0 +1,32 @@
from pydantic import BaseModel
from ollama import AsyncClient
import asyncio
# Define the schema for the response
class FriendInfo(BaseModel):
name: str
age: int
is_available: bool
class FriendList(BaseModel):
friends: list[FriendInfo]
async def main():
client = AsyncClient()
response = await client.chat(
model='llama3.1:8b',
messages=[{'role': 'user', 'content': 'I have two friends. The first is Ollama 22 years old busy saving the world, and the second is Alonso 23 years old and wants to hang out. Return a list of friends in JSON format'}],
format=FriendList.model_json_schema(), # Use Pydantic to generate the schema
options={'temperature': 0}, # Make responses more deterministic
)
# Use Pydantic to validate the response
friends_response = FriendList.model_validate_json(response.message.content)
print(friends_response)
if __name__ == '__main__':
asyncio.run(main())
+24 -10
View File
@@ -41,21 +41,21 @@ subtract_two_numbers_tool = {
},
}
messages = [{'role': 'user', 'content': 'What is three plus one?'}]
print('Prompt:', messages[0]['content'])
available_functions = {
'add_two_numbers': add_two_numbers,
'subtract_two_numbers': subtract_two_numbers,
}
async def main():
client = ollama.AsyncClient()
prompt = 'What is three plus one?'
print('Prompt:', prompt)
available_functions = {
'add_two_numbers': add_two_numbers,
'subtract_two_numbers': subtract_two_numbers,
}
response: ChatResponse = await client.chat(
'llama3.1',
messages=[{'role': 'user', 'content': prompt}],
messages=messages,
tools=[add_two_numbers, subtract_two_numbers_tool],
)
@@ -66,10 +66,24 @@ async def main():
if function_to_call := available_functions.get(tool.function.name):
print('Calling function:', tool.function.name)
print('Arguments:', tool.function.arguments)
print('Function output:', function_to_call(**tool.function.arguments))
output = function_to_call(**tool.function.arguments)
print('Function output:', output)
else:
print('Function', tool.function.name, 'not found')
# Only needed to chat with the model using the tool call results
if response.message.tool_calls:
# Add the function response to messages for the model to use
messages.append(response.message)
messages.append({'role': 'tool', 'content': str(output), 'name': tool.function.name})
# Get final response from model with function outputs
final_response = await client.chat('llama3.1', messages=messages)
print('Final response:', final_response.message.content)
else:
print('No tool calls returned from model')
if __name__ == '__main__':
try:
+2 -2
View File
@@ -31,8 +31,8 @@ while True:
)
# Add the response to the messages to maintain the history
messages.append(
messages += [
{'role': 'user', 'content': user_input},
{'role': 'assistant', 'content': response.message.content},
)
]
print(response.message.content + '\n')
+50
View File
@@ -0,0 +1,50 @@
from pathlib import Path
from pydantic import BaseModel
from typing import List, Optional, Literal
from ollama import chat
from rich import print
# Define the schema for image objects
class Object(BaseModel):
name: str
confidence: float
attributes: Optional[dict] = None
class ImageDescription(BaseModel):
summary: str
objects: List[Object]
scene: str
colors: List[str]
time_of_day: Literal['Morning', 'Afternoon', 'Evening', 'Night']
setting: Literal['Indoor', 'Outdoor', 'Unknown']
text_content: Optional[str] = None
# Get path from user input
path = input('Enter the path to your image: ')
path = Path(path)
# Verify the file exists
if not path.exists():
raise FileNotFoundError(f'Image not found at: {path}')
# Set up chat as usual
response = chat(
model='llama3.2-vision',
format=ImageDescription.model_json_schema(), # Pass in the schema for the response
messages=[
{
'role': 'user',
'content': 'Analyze this image and return a detailed JSON description including objects, scene, colors and any text detected. If you cannot determine certain details, leave those fields empty.',
'images': [path],
},
],
options={'temperature': 0}, # Set temperature to 0 for more deterministic output
)
# Convert received content to the schema
image_analysis = ImageDescription.model_validate_json(response.message.content)
print(image_analysis)
+26
View File
@@ -0,0 +1,26 @@
from ollama import chat
from pydantic import BaseModel
# Define the schema for the response
class FriendInfo(BaseModel):
name: str
age: int
is_available: bool
class FriendList(BaseModel):
friends: list[FriendInfo]
# schema = {'type': 'object', 'properties': {'friends': {'type': 'array', 'items': {'type': 'object', 'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}, 'is_available': {'type': 'boolean'}}, 'required': ['name', 'age', 'is_available']}}}, 'required': ['friends']}
response = chat(
model='llama3.1:8b',
messages=[{'role': 'user', 'content': 'I have two friends. The first is Ollama 22 years old busy saving the world, and the second is Alonso 23 years old and wants to hang out. Return a list of friends in JSON format'}],
format=FriendList.model_json_schema(), # Use Pydantic to generate the schema or format=schema
options={'temperature': 0}, # Make responses more deterministic
)
# Use Pydantic to validate the response
friends_response = FriendList.model_validate_json(response.message.content)
print(friends_response)
+18 -4
View File
@@ -40,8 +40,8 @@ subtract_two_numbers_tool = {
},
}
prompt = 'What is three plus one?'
print('Prompt:', prompt)
messages = [{'role': 'user', 'content': 'What is three plus one?'}]
print('Prompt:', messages[0]['content'])
available_functions = {
'add_two_numbers': add_two_numbers,
@@ -50,7 +50,7 @@ available_functions = {
response: ChatResponse = chat(
'llama3.1',
messages=[{'role': 'user', 'content': prompt}],
messages=messages,
tools=[add_two_numbers, subtract_two_numbers_tool],
)
@@ -61,6 +61,20 @@ if response.message.tool_calls:
if function_to_call := available_functions.get(tool.function.name):
print('Calling function:', tool.function.name)
print('Arguments:', tool.function.arguments)
print('Function output:', function_to_call(**tool.function.arguments))
output = function_to_call(**tool.function.arguments)
print('Function output:', output)
else:
print('Function', tool.function.name, 'not found')
# Only needed to chat with the model using the tool call results
if response.message.tool_calls:
# Add the function response to messages for the model to use
messages.append(response.message)
messages.append({'role': 'tool', 'content': str(output), 'name': tool.function.name})
# Get final response from model with function outputs
final_response = chat('llama3.1', messages=messages)
print('Final response:', final_response.message.content)
else:
print('No tool calls returned from model')
+18 -17
View File
@@ -23,6 +23,8 @@ from typing import (
import sys
from pydantic.json_schema import JsonSchemaValue
from ollama._utils import convert_function_to_tool
@@ -186,7 +188,7 @@ class Client(BaseClient):
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
raw: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -204,7 +206,7 @@ class Client(BaseClient):
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
raw: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -221,7 +223,7 @@ class Client(BaseClient):
context: Optional[Sequence[int]] = None,
stream: bool = False,
raw: Optional[bool] = None,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -263,9 +265,9 @@ class Client(BaseClient):
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[False] = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> ChatResponse: ...
@@ -276,9 +278,9 @@ class Client(BaseClient):
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Iterator[ChatResponse]: ...
@@ -290,7 +292,7 @@ class Client(BaseClient):
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[ChatResponse, Iterator[ChatResponse]]:
@@ -327,7 +329,6 @@ class Client(BaseClient):
Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
"""
return self._request(
ChatResponse,
'POST',
@@ -689,7 +690,7 @@ class AsyncClient(BaseClient):
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
raw: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -707,7 +708,7 @@ class AsyncClient(BaseClient):
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
raw: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -724,7 +725,7 @@ class AsyncClient(BaseClient):
context: Optional[Sequence[int]] = None,
stream: bool = False,
raw: Optional[bool] = None,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -765,9 +766,9 @@ class AsyncClient(BaseClient):
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[False] = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> ChatResponse: ...
@@ -780,7 +781,7 @@ class AsyncClient(BaseClient):
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> AsyncIterator[ChatResponse]: ...
@@ -790,9 +791,9 @@ class AsyncClient(BaseClient):
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
+83 -6
View File
@@ -4,6 +4,7 @@ from pathlib import Path
from datetime import datetime
from typing import Any, Mapping, Optional, Union, Sequence
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import Annotated, Literal
from pydantic import (
@@ -17,16 +18,87 @@ from pydantic import (
class SubscriptableBaseModel(BaseModel):
def __getitem__(self, key: str) -> Any:
return getattr(self, key)
"""
>>> msg = Message(role='user')
>>> msg['role']
'user'
>>> msg = Message(role='user')
>>> msg['nonexistent']
Traceback (most recent call last):
KeyError: 'nonexistent'
"""
if key in self:
return getattr(self, key)
raise KeyError(key)
def __setitem__(self, key: str, value: Any) -> None:
"""
>>> msg = Message(role='user')
>>> msg['role'] = 'assistant'
>>> msg['role']
'assistant'
>>> tool_call = Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))
>>> msg = Message(role='user', content='hello')
>>> msg['tool_calls'] = [tool_call]
>>> msg['tool_calls'][0]['function']['name']
'foo'
"""
setattr(self, key, value)
def __contains__(self, key: str) -> bool:
return hasattr(self, key)
"""
>>> msg = Message(role='user')
>>> 'nonexistent' in msg
False
>>> 'role' in msg
True
>>> 'content' in msg
False
>>> msg.content = 'hello!'
>>> 'content' in msg
True
>>> msg = Message(role='user', content='hello!')
>>> 'content' in msg
True
>>> 'tool_calls' in msg
False
>>> msg['tool_calls'] = []
>>> 'tool_calls' in msg
True
>>> msg['tool_calls'] = [Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))]
>>> 'tool_calls' in msg
True
>>> msg['tool_calls'] = None
>>> 'tool_calls' in msg
True
>>> tool = Tool()
>>> 'type' in tool
True
"""
if key in self.model_fields_set:
return True
if key in self.model_fields:
return self.model_fields[key].default is not None
return False
def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default)
"""
>>> msg = Message(role='user')
>>> msg.get('role')
'user'
>>> msg = Message(role='user')
>>> msg.get('nonexistent')
>>> msg = Message(role='user')
>>> msg.get('nonexistent', 'default')
'default'
>>> msg = Message(role='user', tool_calls=[ Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))])
>>> msg.get('tool_calls')[0]['function']['name']
'foo'
"""
return self[key] if key in self else default
class Options(SubscriptableBaseModel):
@@ -79,7 +151,7 @@ class BaseGenerateRequest(BaseStreamableRequest):
options: Optional[Union[Mapping[str, Any], Options]] = None
'Options to use for the request.'
format: Optional[Literal['', 'json']] = None
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None
'Format of the response.'
keep_alive: Optional[Union[float, str]] = None
@@ -95,9 +167,14 @@ class Image(BaseModel):
return b64encode(self.value.read_bytes() if isinstance(self.value, Path) else self.value).decode()
if isinstance(self.value, str):
if Path(self.value).exists():
return b64encode(Path(self.value).read_bytes()).decode()
try:
if Path(self.value).exists():
return b64encode(Path(self.value).read_bytes()).decode()
except Exception:
# Long base64 string can't be wrapped in Path, so try to treat as base64 string
pass
# String might be a file path, but might not exist
if self.value.split('.')[-1] in ('png', 'jpg', 'jpeg', 'webp'):
raise ValueError(f'File {self.value} does not exist')
Generated
+3 -3
View File
@@ -559,13 +559,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
[[package]]
name = "pytest"
version = "8.3.3"
version = "8.3.4"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"},
{file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"},
{file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"},
{file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"},
]
[package.dependencies]
+225 -1
View File
@@ -1,7 +1,7 @@
import os
import io
import json
from pydantic import ValidationError
from pydantic import ValidationError, BaseModel
import pytest
import tempfile
from pathlib import Path
@@ -122,6 +122,128 @@ def test_client_chat_images(httpserver: HTTPServer):
assert response['message']['content'] == "I don't know."
def test_client_chat_format_json(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'format': 'json',
'stream': False,
},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': '{"answer": "Because of Rayleigh scattering"}',
},
}
)
client = Client(httpserver.url_for('/'))
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json')
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering"}'
def test_client_chat_format_pydantic(httpserver: HTTPServer):
class ResponseFormat(BaseModel):
answer: str
confidence: float
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
'stream': False,
},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
},
}
)
client = Client(httpserver.url_for('/'))
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format=ResponseFormat.model_json_schema())
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
@pytest.mark.asyncio
async def test_async_client_chat_format_json(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'format': 'json',
'stream': False,
},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': '{"answer": "Because of Rayleigh scattering"}',
},
}
)
client = AsyncClient(httpserver.url_for('/'))
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json')
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering"}'
@pytest.mark.asyncio
async def test_async_client_chat_format_pydantic(httpserver: HTTPServer):
class ResponseFormat(BaseModel):
answer: str
confidence: float
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
'stream': False,
},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
},
}
)
client = AsyncClient(httpserver.url_for('/'))
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format=ResponseFormat.model_json_schema())
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
def test_client_generate(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
@@ -205,6 +327,108 @@ def test_client_generate_images(httpserver: HTTPServer):
assert response['response'] == 'Because it is.'
def test_client_generate_format_json(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'format': 'json',
'stream': False,
},
).respond_with_json(
{
'model': 'dummy',
'response': '{"answer": "Because of Rayleigh scattering"}',
}
)
client = Client(httpserver.url_for('/'))
response = client.generate('dummy', 'Why is the sky blue?', format='json')
assert response['model'] == 'dummy'
assert response['response'] == '{"answer": "Because of Rayleigh scattering"}'
def test_client_generate_format_pydantic(httpserver: HTTPServer):
class ResponseFormat(BaseModel):
answer: str
confidence: float
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
'stream': False,
},
).respond_with_json(
{
'model': 'dummy',
'response': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
}
)
client = Client(httpserver.url_for('/'))
response = client.generate('dummy', 'Why is the sky blue?', format=ResponseFormat.model_json_schema())
assert response['model'] == 'dummy'
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
@pytest.mark.asyncio
async def test_async_client_generate_format_json(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'format': 'json',
'stream': False,
},
).respond_with_json(
{
'model': 'dummy',
'response': '{"answer": "Because of Rayleigh scattering"}',
}
)
client = AsyncClient(httpserver.url_for('/'))
response = await client.generate('dummy', 'Why is the sky blue?', format='json')
assert response['model'] == 'dummy'
assert response['response'] == '{"answer": "Because of Rayleigh scattering"}'
@pytest.mark.asyncio
async def test_async_client_generate_format_pydantic(httpserver: HTTPServer):
class ResponseFormat(BaseModel):
answer: str
confidence: float
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
'stream': False,
},
).respond_with_json(
{
'model': 'dummy',
'response': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
}
)
client = AsyncClient(httpserver.url_for('/'))
response = await client.generate('dummy', 'Why is the sky blue?', format=ResponseFormat.model_json_schema())
assert response['model'] == 'dummy'
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
def test_client_pull(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/pull',
+6
View File
@@ -19,6 +19,12 @@ def test_image_serialization_base64_string():
assert img.model_dump() == b64_str # Should return as-is if valid base64
def test_image_serialization_long_base64_string():
b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n' * 1000
img = Image(value=b64_str)
assert img.model_dump() == b64_str # Should return as-is if valid base64
def test_image_serialization_plain_string():
img = Image(value='not a path or base64')
assert img.model_dump() == 'not a path or base64' # Should return as-is