mirror of
https://github.com/ollama/ollama-python.git
synced 2026-06-15 20:54:51 +00:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1066246ab5 | |||
| 4b10dee2b2 | |||
| e956a331e8 | |||
| 12f7302d5f | |||
| 366180aa8f | |||
| d6528cf731 | |||
| b50a65b27d | |||
| 758a1d2933 | |||
| d4c38978d1 | |||
| d8d98e17b2 | |||
| ec2c8fdd8d | |||
| ea0e0dc692 | |||
| 6c44bb2729 | |||
| 2095fc9107 | |||
| 64e3723e6b | |||
| 1e22f2e118 | |||
| 00c64332cc | |||
| 986fb4c7b3 | |||
| c6ade633b8 |
@@ -2,7 +2,7 @@ name: test
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
paths-ignore:
|
||||
- 'examples/**'
|
||||
- '**/README.md'
|
||||
|
||||
|
||||
@@ -37,9 +37,6 @@ See [_types.py](ollama/_types.py) for more information on the response types.
|
||||
|
||||
Response streaming can be enabled by setting `stream=True`.
|
||||
|
||||
> [!NOTE]
|
||||
> Streaming Tool/Function calling is not yet supported.
|
||||
|
||||
```python
|
||||
from ollama import chat
|
||||
|
||||
|
||||
@@ -30,6 +30,12 @@ python3 examples/<example>.py
|
||||
- [multimodal_generate.py](multimodal_generate.py)
|
||||
|
||||
|
||||
### Structured Outputs - Generate structured outputs with a model
|
||||
- [structured-outputs.py](structured-outputs.py)
|
||||
- [async-structured-outputs.py](async-structured-outputs.py)
|
||||
- [structured-outputs-image.py](structured-outputs-image.py)
|
||||
|
||||
|
||||
### Ollama List - List all downloaded models and their properties
|
||||
- [list.py](list.py)
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from pydantic import BaseModel
|
||||
from ollama import AsyncClient
|
||||
import asyncio
|
||||
|
||||
|
||||
# Define the schema for the response
|
||||
class FriendInfo(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
is_available: bool
|
||||
|
||||
|
||||
class FriendList(BaseModel):
|
||||
friends: list[FriendInfo]
|
||||
|
||||
|
||||
async def main():
|
||||
client = AsyncClient()
|
||||
response = await client.chat(
|
||||
model='llama3.1:8b',
|
||||
messages=[{'role': 'user', 'content': 'I have two friends. The first is Ollama 22 years old busy saving the world, and the second is Alonso 23 years old and wants to hang out. Return a list of friends in JSON format'}],
|
||||
format=FriendList.model_json_schema(), # Use Pydantic to generate the schema
|
||||
options={'temperature': 0}, # Make responses more deterministic
|
||||
)
|
||||
|
||||
# Use Pydantic to validate the response
|
||||
friends_response = FriendList.model_validate_json(response.message.content)
|
||||
print(friends_response)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
+24
-10
@@ -41,21 +41,21 @@ subtract_two_numbers_tool = {
|
||||
},
|
||||
}
|
||||
|
||||
messages = [{'role': 'user', 'content': 'What is three plus one?'}]
|
||||
print('Prompt:', messages[0]['content'])
|
||||
|
||||
available_functions = {
|
||||
'add_two_numbers': add_two_numbers,
|
||||
'subtract_two_numbers': subtract_two_numbers,
|
||||
}
|
||||
|
||||
|
||||
async def main():
|
||||
client = ollama.AsyncClient()
|
||||
|
||||
prompt = 'What is three plus one?'
|
||||
print('Prompt:', prompt)
|
||||
|
||||
available_functions = {
|
||||
'add_two_numbers': add_two_numbers,
|
||||
'subtract_two_numbers': subtract_two_numbers,
|
||||
}
|
||||
|
||||
response: ChatResponse = await client.chat(
|
||||
'llama3.1',
|
||||
messages=[{'role': 'user', 'content': prompt}],
|
||||
messages=messages,
|
||||
tools=[add_two_numbers, subtract_two_numbers_tool],
|
||||
)
|
||||
|
||||
@@ -66,10 +66,24 @@ async def main():
|
||||
if function_to_call := available_functions.get(tool.function.name):
|
||||
print('Calling function:', tool.function.name)
|
||||
print('Arguments:', tool.function.arguments)
|
||||
print('Function output:', function_to_call(**tool.function.arguments))
|
||||
output = function_to_call(**tool.function.arguments)
|
||||
print('Function output:', output)
|
||||
else:
|
||||
print('Function', tool.function.name, 'not found')
|
||||
|
||||
# Only needed to chat with the model using the tool call results
|
||||
if response.message.tool_calls:
|
||||
# Add the function response to messages for the model to use
|
||||
messages.append(response.message)
|
||||
messages.append({'role': 'tool', 'content': str(output), 'name': tool.function.name})
|
||||
|
||||
# Get final response from model with function outputs
|
||||
final_response = await client.chat('llama3.1', messages=messages)
|
||||
print('Final response:', final_response.message.content)
|
||||
|
||||
else:
|
||||
print('No tool calls returned from model')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
|
||||
@@ -31,8 +31,8 @@ while True:
|
||||
)
|
||||
|
||||
# Add the response to the messages to maintain the history
|
||||
messages.append(
|
||||
messages += [
|
||||
{'role': 'user', 'content': user_input},
|
||||
{'role': 'assistant', 'content': response.message.content},
|
||||
)
|
||||
]
|
||||
print(response.message.content + '\n')
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Literal
|
||||
from ollama import chat
|
||||
from rich import print
|
||||
|
||||
|
||||
# Define the schema for image objects
|
||||
class Object(BaseModel):
|
||||
name: str
|
||||
confidence: float
|
||||
attributes: Optional[dict] = None
|
||||
|
||||
|
||||
class ImageDescription(BaseModel):
|
||||
summary: str
|
||||
objects: List[Object]
|
||||
scene: str
|
||||
colors: List[str]
|
||||
time_of_day: Literal['Morning', 'Afternoon', 'Evening', 'Night']
|
||||
setting: Literal['Indoor', 'Outdoor', 'Unknown']
|
||||
text_content: Optional[str] = None
|
||||
|
||||
|
||||
# Get path from user input
|
||||
path = input('Enter the path to your image: ')
|
||||
path = Path(path)
|
||||
|
||||
# Verify the file exists
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f'Image not found at: {path}')
|
||||
|
||||
# Set up chat as usual
|
||||
response = chat(
|
||||
model='llama3.2-vision',
|
||||
format=ImageDescription.model_json_schema(), # Pass in the schema for the response
|
||||
messages=[
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Analyze this image and return a detailed JSON description including objects, scene, colors and any text detected. If you cannot determine certain details, leave those fields empty.',
|
||||
'images': [path],
|
||||
},
|
||||
],
|
||||
options={'temperature': 0}, # Set temperature to 0 for more deterministic output
|
||||
)
|
||||
|
||||
|
||||
# Convert received content to the schema
|
||||
image_analysis = ImageDescription.model_validate_json(response.message.content)
|
||||
print(image_analysis)
|
||||
@@ -0,0 +1,26 @@
|
||||
from ollama import chat
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Define the schema for the response
|
||||
class FriendInfo(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
is_available: bool
|
||||
|
||||
|
||||
class FriendList(BaseModel):
|
||||
friends: list[FriendInfo]
|
||||
|
||||
|
||||
# schema = {'type': 'object', 'properties': {'friends': {'type': 'array', 'items': {'type': 'object', 'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}, 'is_available': {'type': 'boolean'}}, 'required': ['name', 'age', 'is_available']}}}, 'required': ['friends']}
|
||||
response = chat(
|
||||
model='llama3.1:8b',
|
||||
messages=[{'role': 'user', 'content': 'I have two friends. The first is Ollama 22 years old busy saving the world, and the second is Alonso 23 years old and wants to hang out. Return a list of friends in JSON format'}],
|
||||
format=FriendList.model_json_schema(), # Use Pydantic to generate the schema or format=schema
|
||||
options={'temperature': 0}, # Make responses more deterministic
|
||||
)
|
||||
|
||||
# Use Pydantic to validate the response
|
||||
friends_response = FriendList.model_validate_json(response.message.content)
|
||||
print(friends_response)
|
||||
+18
-4
@@ -40,8 +40,8 @@ subtract_two_numbers_tool = {
|
||||
},
|
||||
}
|
||||
|
||||
prompt = 'What is three plus one?'
|
||||
print('Prompt:', prompt)
|
||||
messages = [{'role': 'user', 'content': 'What is three plus one?'}]
|
||||
print('Prompt:', messages[0]['content'])
|
||||
|
||||
available_functions = {
|
||||
'add_two_numbers': add_two_numbers,
|
||||
@@ -50,7 +50,7 @@ available_functions = {
|
||||
|
||||
response: ChatResponse = chat(
|
||||
'llama3.1',
|
||||
messages=[{'role': 'user', 'content': prompt}],
|
||||
messages=messages,
|
||||
tools=[add_two_numbers, subtract_two_numbers_tool],
|
||||
)
|
||||
|
||||
@@ -61,6 +61,20 @@ if response.message.tool_calls:
|
||||
if function_to_call := available_functions.get(tool.function.name):
|
||||
print('Calling function:', tool.function.name)
|
||||
print('Arguments:', tool.function.arguments)
|
||||
print('Function output:', function_to_call(**tool.function.arguments))
|
||||
output = function_to_call(**tool.function.arguments)
|
||||
print('Function output:', output)
|
||||
else:
|
||||
print('Function', tool.function.name, 'not found')
|
||||
|
||||
# Only needed to chat with the model using the tool call results
|
||||
if response.message.tool_calls:
|
||||
# Add the function response to messages for the model to use
|
||||
messages.append(response.message)
|
||||
messages.append({'role': 'tool', 'content': str(output), 'name': tool.function.name})
|
||||
|
||||
# Get final response from model with function outputs
|
||||
final_response = chat('llama3.1', messages=messages)
|
||||
print('Final response:', final_response.message.content)
|
||||
|
||||
else:
|
||||
print('No tool calls returned from model')
|
||||
|
||||
+18
-17
@@ -23,6 +23,8 @@ from typing import (
|
||||
|
||||
import sys
|
||||
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
|
||||
|
||||
from ollama._utils import convert_function_to_tool
|
||||
|
||||
@@ -186,7 +188,7 @@ class Client(BaseClient):
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: Literal[False] = False,
|
||||
raw: bool = False,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
@@ -204,7 +206,7 @@ class Client(BaseClient):
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: Literal[True] = True,
|
||||
raw: bool = False,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
@@ -221,7 +223,7 @@ class Client(BaseClient):
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: bool = False,
|
||||
raw: Optional[bool] = None,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
@@ -263,9 +265,9 @@ class Client(BaseClient):
|
||||
model: str = '',
|
||||
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
|
||||
*,
|
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
|
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||
stream: Literal[False] = False,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> ChatResponse: ...
|
||||
@@ -276,9 +278,9 @@ class Client(BaseClient):
|
||||
model: str = '',
|
||||
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
|
||||
*,
|
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
|
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||
stream: Literal[True] = True,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Iterator[ChatResponse]: ...
|
||||
@@ -290,7 +292,7 @@ class Client(BaseClient):
|
||||
*,
|
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||
stream: bool = False,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Union[ChatResponse, Iterator[ChatResponse]]:
|
||||
@@ -327,7 +329,6 @@ class Client(BaseClient):
|
||||
|
||||
Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
|
||||
"""
|
||||
|
||||
return self._request(
|
||||
ChatResponse,
|
||||
'POST',
|
||||
@@ -689,7 +690,7 @@ class AsyncClient(BaseClient):
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: Literal[False] = False,
|
||||
raw: bool = False,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
@@ -707,7 +708,7 @@ class AsyncClient(BaseClient):
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: Literal[True] = True,
|
||||
raw: bool = False,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
@@ -724,7 +725,7 @@ class AsyncClient(BaseClient):
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: bool = False,
|
||||
raw: Optional[bool] = None,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
@@ -765,9 +766,9 @@ class AsyncClient(BaseClient):
|
||||
model: str = '',
|
||||
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
|
||||
*,
|
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
|
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||
stream: Literal[False] = False,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> ChatResponse: ...
|
||||
@@ -780,7 +781,7 @@ class AsyncClient(BaseClient):
|
||||
*,
|
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||
stream: Literal[True] = True,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> AsyncIterator[ChatResponse]: ...
|
||||
@@ -790,9 +791,9 @@ class AsyncClient(BaseClient):
|
||||
model: str = '',
|
||||
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
|
||||
*,
|
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
|
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||
stream: bool = False,
|
||||
format: Optional[Literal['', 'json']] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
|
||||
|
||||
+83
-6
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Any, Mapping, Optional, Union, Sequence
|
||||
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from typing_extensions import Annotated, Literal
|
||||
|
||||
from pydantic import (
|
||||
@@ -17,16 +18,87 @@ from pydantic import (
|
||||
|
||||
class SubscriptableBaseModel(BaseModel):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return getattr(self, key)
|
||||
"""
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg['role']
|
||||
'user'
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg['nonexistent']
|
||||
Traceback (most recent call last):
|
||||
KeyError: 'nonexistent'
|
||||
"""
|
||||
if key in self:
|
||||
return getattr(self, key)
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
"""
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg['role'] = 'assistant'
|
||||
>>> msg['role']
|
||||
'assistant'
|
||||
>>> tool_call = Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))
|
||||
>>> msg = Message(role='user', content='hello')
|
||||
>>> msg['tool_calls'] = [tool_call]
|
||||
>>> msg['tool_calls'][0]['function']['name']
|
||||
'foo'
|
||||
"""
|
||||
setattr(self, key, value)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return hasattr(self, key)
|
||||
"""
|
||||
>>> msg = Message(role='user')
|
||||
>>> 'nonexistent' in msg
|
||||
False
|
||||
>>> 'role' in msg
|
||||
True
|
||||
>>> 'content' in msg
|
||||
False
|
||||
>>> msg.content = 'hello!'
|
||||
>>> 'content' in msg
|
||||
True
|
||||
>>> msg = Message(role='user', content='hello!')
|
||||
>>> 'content' in msg
|
||||
True
|
||||
>>> 'tool_calls' in msg
|
||||
False
|
||||
>>> msg['tool_calls'] = []
|
||||
>>> 'tool_calls' in msg
|
||||
True
|
||||
>>> msg['tool_calls'] = [Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))]
|
||||
>>> 'tool_calls' in msg
|
||||
True
|
||||
>>> msg['tool_calls'] = None
|
||||
>>> 'tool_calls' in msg
|
||||
True
|
||||
>>> tool = Tool()
|
||||
>>> 'type' in tool
|
||||
True
|
||||
"""
|
||||
if key in self.model_fields_set:
|
||||
return True
|
||||
|
||||
if key in self.model_fields:
|
||||
return self.model_fields[key].default is not None
|
||||
|
||||
return False
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return getattr(self, key, default)
|
||||
"""
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg.get('role')
|
||||
'user'
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg.get('nonexistent')
|
||||
>>> msg = Message(role='user')
|
||||
>>> msg.get('nonexistent', 'default')
|
||||
'default'
|
||||
>>> msg = Message(role='user', tool_calls=[ Message.ToolCall(function=Message.ToolCall.Function(name='foo', arguments={}))])
|
||||
>>> msg.get('tool_calls')[0]['function']['name']
|
||||
'foo'
|
||||
"""
|
||||
return self[key] if key in self else default
|
||||
|
||||
|
||||
class Options(SubscriptableBaseModel):
|
||||
@@ -79,7 +151,7 @@ class BaseGenerateRequest(BaseStreamableRequest):
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None
|
||||
'Options to use for the request.'
|
||||
|
||||
format: Optional[Literal['', 'json']] = None
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None
|
||||
'Format of the response.'
|
||||
|
||||
keep_alive: Optional[Union[float, str]] = None
|
||||
@@ -95,9 +167,14 @@ class Image(BaseModel):
|
||||
return b64encode(self.value.read_bytes() if isinstance(self.value, Path) else self.value).decode()
|
||||
|
||||
if isinstance(self.value, str):
|
||||
if Path(self.value).exists():
|
||||
return b64encode(Path(self.value).read_bytes()).decode()
|
||||
try:
|
||||
if Path(self.value).exists():
|
||||
return b64encode(Path(self.value).read_bytes()).decode()
|
||||
except Exception:
|
||||
# Long base64 string can't be wrapped in Path, so try to treat as base64 string
|
||||
pass
|
||||
|
||||
# String might be a file path, but might not exist
|
||||
if self.value.split('.')[-1] in ('png', 'jpg', 'jpeg', 'webp'):
|
||||
raise ValueError(f'File {self.value} does not exist')
|
||||
|
||||
|
||||
Generated
+3
-3
@@ -559,13 +559,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.3.3"
|
||||
version = "8.3.4"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"},
|
||||
{file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"},
|
||||
{file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"},
|
||||
{file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
+225
-1
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
from pydantic import ValidationError
|
||||
from pydantic import ValidationError, BaseModel
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
@@ -122,6 +122,128 @@ def test_client_chat_images(httpserver: HTTPServer):
|
||||
assert response['message']['content'] == "I don't know."
|
||||
|
||||
|
||||
def test_client_chat_format_json(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'format': 'json',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': '{"answer": "Because of Rayleigh scattering"}',
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json')
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||
|
||||
|
||||
def test_client_chat_format_pydantic(httpserver: HTTPServer):
|
||||
class ResponseFormat(BaseModel):
|
||||
answer: str
|
||||
confidence: float
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format=ResponseFormat.model_json_schema())
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_chat_format_json(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'format': 'json',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': '{"answer": "Because of Rayleigh scattering"}',
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json')
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_chat_format_pydantic(httpserver: HTTPServer):
|
||||
class ResponseFormat(BaseModel):
|
||||
answer: str
|
||||
confidence: float
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format=ResponseFormat.model_json_schema())
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||
|
||||
|
||||
def test_client_generate(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
@@ -205,6 +327,108 @@ def test_client_generate_images(httpserver: HTTPServer):
|
||||
assert response['response'] == 'Because it is.'
|
||||
|
||||
|
||||
def test_client_generate_format_json(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'format': 'json',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': '{"answer": "Because of Rayleigh scattering"}',
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?', format='json')
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||
|
||||
|
||||
def test_client_generate_format_pydantic(httpserver: HTTPServer):
|
||||
class ResponseFormat(BaseModel):
|
||||
answer: str
|
||||
confidence: float
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?', format=ResponseFormat.model_json_schema())
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_generate_format_json(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'format': 'json',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': '{"answer": "Because of Rayleigh scattering"}',
|
||||
}
|
||||
)
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.generate('dummy', 'Why is the sky blue?', format='json')
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_generate_format_pydantic(httpserver: HTTPServer):
|
||||
class ResponseFormat(BaseModel):
|
||||
answer: str
|
||||
confidence: float
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||
}
|
||||
)
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.generate('dummy', 'Why is the sky blue?', format=ResponseFormat.model_json_schema())
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||
|
||||
|
||||
def test_client_pull(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/pull',
|
||||
|
||||
@@ -19,6 +19,12 @@ def test_image_serialization_base64_string():
|
||||
assert img.model_dump() == b64_str # Should return as-is if valid base64
|
||||
|
||||
|
||||
def test_image_serialization_long_base64_string():
|
||||
b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n' * 1000
|
||||
img = Image(value=b64_str)
|
||||
assert img.model_dump() == b64_str # Should return as-is if valid base64
|
||||
|
||||
|
||||
def test_image_serialization_plain_string():
|
||||
img = Image(value='not a path or base64')
|
||||
assert img.model_dump() == 'not a path or base64' # Should return as-is
|
||||
|
||||
Reference in New Issue
Block a user