mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-13 21:57:16 +08:00
Structured outputs support with examples (#354)
This commit is contained in:
parent
e956a331e8
commit
4b10dee2b2
@ -37,9 +37,6 @@ See [_types.py](ollama/_types.py) for more information on the response types.
|
|||||||
|
|
||||||
Response streaming can be enabled by setting `stream=True`.
|
Response streaming can be enabled by setting `stream=True`.
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> Streaming Tool/Function calling is not yet supported.
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from ollama import chat
|
from ollama import chat
|
||||||
|
|
||||||
|
|||||||
@ -30,6 +30,12 @@ python3 examples/<example>.py
|
|||||||
- [multimodal_generate.py](multimodal_generate.py)
|
- [multimodal_generate.py](multimodal_generate.py)
|
||||||
|
|
||||||
|
|
||||||
|
### Structured Outputs - Generate structured outputs with a model
|
||||||
|
- [structured-outputs.py](structured-outputs.py)
|
||||||
|
- [async-structured-outputs.py](async-structured-outputs.py)
|
||||||
|
- [structured-outputs-image.py](structured-outputs-image.py)
|
||||||
|
|
||||||
|
|
||||||
### Ollama List - List all downloaded models and their properties
|
### Ollama List - List all downloaded models and their properties
|
||||||
- [list.py](list.py)
|
- [list.py](list.py)
|
||||||
|
|
||||||
|
|||||||
32
examples/async-structured-outputs.py
Normal file
32
examples/async-structured-outputs.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from ollama import AsyncClient
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
# Define the schema for the response
|
||||||
|
class FriendInfo(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
is_available: bool
|
||||||
|
|
||||||
|
|
||||||
|
class FriendList(BaseModel):
|
||||||
|
friends: list[FriendInfo]
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
client = AsyncClient()
|
||||||
|
response = await client.chat(
|
||||||
|
model='llama3.1:8b',
|
||||||
|
messages=[{'role': 'user', 'content': 'I have two friends. The first is Ollama 22 years old busy saving the world, and the second is Alonso 23 years old and wants to hang out. Return a list of friends in JSON format'}],
|
||||||
|
format=FriendList.model_json_schema(), # Use Pydantic to generate the schema
|
||||||
|
options={'temperature': 0}, # Make responses more deterministic
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use Pydantic to validate the response
|
||||||
|
friends_response = FriendList.model_validate_json(response.message.content)
|
||||||
|
print(friends_response)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(main())
|
||||||
50
examples/structured-outputs-image.py
Normal file
50
examples/structured-outputs-image.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional, Literal
|
||||||
|
from ollama import chat
|
||||||
|
from rich import print
|
||||||
|
|
||||||
|
|
||||||
|
# Define the schema for image objects
|
||||||
|
class Object(BaseModel):
|
||||||
|
name: str
|
||||||
|
confidence: float
|
||||||
|
attributes: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ImageDescription(BaseModel):
|
||||||
|
summary: str
|
||||||
|
objects: List[Object]
|
||||||
|
scene: str
|
||||||
|
colors: List[str]
|
||||||
|
time_of_day: Literal['Morning', 'Afternoon', 'Evening', 'Night']
|
||||||
|
setting: Literal['Indoor', 'Outdoor', 'Unknown']
|
||||||
|
text_content: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Get path from user input
|
||||||
|
path = input('Enter the path to your image: ')
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
# Verify the file exists
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f'Image not found at: {path}')
|
||||||
|
|
||||||
|
# Set up chat as usual
|
||||||
|
response = chat(
|
||||||
|
model='llama3.2-vision',
|
||||||
|
format=ImageDescription.model_json_schema(), # Pass in the schema for the response
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
'role': 'user',
|
||||||
|
'content': 'Analyze this image and return a detailed JSON description including objects, scene, colors and any text detected. If you cannot determine certain details, leave those fields empty.',
|
||||||
|
'images': [path],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
options={'temperature': 0}, # Set temperature to 0 for more deterministic output
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Convert received content to the schema
|
||||||
|
image_analysis = ImageDescription.model_validate_json(response.message.content)
|
||||||
|
print(image_analysis)
|
||||||
26
examples/structured-outputs.py
Normal file
26
examples/structured-outputs.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from ollama import chat
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
# Define the schema for the response
|
||||||
|
class FriendInfo(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
is_available: bool
|
||||||
|
|
||||||
|
|
||||||
|
class FriendList(BaseModel):
|
||||||
|
friends: list[FriendInfo]
|
||||||
|
|
||||||
|
|
||||||
|
# schema = {'type': 'object', 'properties': {'friends': {'type': 'array', 'items': {'type': 'object', 'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}, 'is_available': {'type': 'boolean'}}, 'required': ['name', 'age', 'is_available']}}}, 'required': ['friends']}
|
||||||
|
response = chat(
|
||||||
|
model='llama3.1:8b',
|
||||||
|
messages=[{'role': 'user', 'content': 'I have two friends. The first is Ollama 22 years old busy saving the world, and the second is Alonso 23 years old and wants to hang out. Return a list of friends in JSON format'}],
|
||||||
|
format=FriendList.model_json_schema(), # Use Pydantic to generate the schema or format=schema
|
||||||
|
options={'temperature': 0}, # Make responses more deterministic
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use Pydantic to validate the response
|
||||||
|
friends_response = FriendList.model_validate_json(response.message.content)
|
||||||
|
print(friends_response)
|
||||||
@ -23,6 +23,8 @@ from typing import (
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
|
|
||||||
|
|
||||||
from ollama._utils import convert_function_to_tool
|
from ollama._utils import convert_function_to_tool
|
||||||
|
|
||||||
@ -186,7 +188,7 @@ class Client(BaseClient):
|
|||||||
context: Optional[Sequence[int]] = None,
|
context: Optional[Sequence[int]] = None,
|
||||||
stream: Literal[False] = False,
|
stream: Literal[False] = False,
|
||||||
raw: bool = False,
|
raw: bool = False,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
|
||||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
@ -204,7 +206,7 @@ class Client(BaseClient):
|
|||||||
context: Optional[Sequence[int]] = None,
|
context: Optional[Sequence[int]] = None,
|
||||||
stream: Literal[True] = True,
|
stream: Literal[True] = True,
|
||||||
raw: bool = False,
|
raw: bool = False,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
|
||||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
@ -221,7 +223,7 @@ class Client(BaseClient):
|
|||||||
context: Optional[Sequence[int]] = None,
|
context: Optional[Sequence[int]] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
raw: Optional[bool] = None,
|
raw: Optional[bool] = None,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
|
||||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
@ -265,7 +267,7 @@ class Client(BaseClient):
|
|||||||
*,
|
*,
|
||||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||||
stream: Literal[False] = False,
|
stream: Literal[False] = False,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
) -> ChatResponse: ...
|
) -> ChatResponse: ...
|
||||||
@ -278,7 +280,7 @@ class Client(BaseClient):
|
|||||||
*,
|
*,
|
||||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||||
stream: Literal[True] = True,
|
stream: Literal[True] = True,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
) -> Iterator[ChatResponse]: ...
|
) -> Iterator[ChatResponse]: ...
|
||||||
@ -290,7 +292,7 @@ class Client(BaseClient):
|
|||||||
*,
|
*,
|
||||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
) -> Union[ChatResponse, Iterator[ChatResponse]]:
|
) -> Union[ChatResponse, Iterator[ChatResponse]]:
|
||||||
@ -327,7 +329,6 @@ class Client(BaseClient):
|
|||||||
|
|
||||||
Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
|
Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self._request(
|
return self._request(
|
||||||
ChatResponse,
|
ChatResponse,
|
||||||
'POST',
|
'POST',
|
||||||
@ -689,7 +690,7 @@ class AsyncClient(BaseClient):
|
|||||||
context: Optional[Sequence[int]] = None,
|
context: Optional[Sequence[int]] = None,
|
||||||
stream: Literal[False] = False,
|
stream: Literal[False] = False,
|
||||||
raw: bool = False,
|
raw: bool = False,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
@ -707,7 +708,7 @@ class AsyncClient(BaseClient):
|
|||||||
context: Optional[Sequence[int]] = None,
|
context: Optional[Sequence[int]] = None,
|
||||||
stream: Literal[True] = True,
|
stream: Literal[True] = True,
|
||||||
raw: bool = False,
|
raw: bool = False,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
@ -724,7 +725,7 @@ class AsyncClient(BaseClient):
|
|||||||
context: Optional[Sequence[int]] = None,
|
context: Optional[Sequence[int]] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
raw: Optional[bool] = None,
|
raw: Optional[bool] = None,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
@ -767,7 +768,7 @@ class AsyncClient(BaseClient):
|
|||||||
*,
|
*,
|
||||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||||
stream: Literal[False] = False,
|
stream: Literal[False] = False,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
) -> ChatResponse: ...
|
) -> ChatResponse: ...
|
||||||
@ -780,7 +781,7 @@ class AsyncClient(BaseClient):
|
|||||||
*,
|
*,
|
||||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||||
stream: Literal[True] = True,
|
stream: Literal[True] = True,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
) -> AsyncIterator[ChatResponse]: ...
|
) -> AsyncIterator[ChatResponse]: ...
|
||||||
@ -792,7 +793,7 @@ class AsyncClient(BaseClient):
|
|||||||
*,
|
*,
|
||||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
|
) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from pathlib import Path
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Mapping, Optional, Union, Sequence
|
from typing import Any, Mapping, Optional, Union, Sequence
|
||||||
|
|
||||||
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
from typing_extensions import Annotated, Literal
|
from typing_extensions import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
@ -150,7 +151,7 @@ class BaseGenerateRequest(BaseStreamableRequest):
|
|||||||
options: Optional[Union[Mapping[str, Any], Options]] = None
|
options: Optional[Union[Mapping[str, Any], Options]] = None
|
||||||
'Options to use for the request.'
|
'Options to use for the request.'
|
||||||
|
|
||||||
format: Optional[Literal['', 'json']] = None
|
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None
|
||||||
'Format of the response.'
|
'Format of the response.'
|
||||||
|
|
||||||
keep_alive: Optional[Union[float, str]] = None
|
keep_alive: Optional[Union[float, str]] = None
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError, BaseModel
|
||||||
import pytest
|
import pytest
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -122,6 +122,128 @@ def test_client_chat_images(httpserver: HTTPServer):
|
|||||||
assert response['message']['content'] == "I don't know."
|
assert response['message']['content'] == "I don't know."
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_chat_format_json(httpserver: HTTPServer):
|
||||||
|
httpserver.expect_ordered_request(
|
||||||
|
'/api/chat',
|
||||||
|
method='POST',
|
||||||
|
json={
|
||||||
|
'model': 'dummy',
|
||||||
|
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||||
|
'tools': [],
|
||||||
|
'format': 'json',
|
||||||
|
'stream': False,
|
||||||
|
},
|
||||||
|
).respond_with_json(
|
||||||
|
{
|
||||||
|
'model': 'dummy',
|
||||||
|
'message': {
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': '{"answer": "Because of Rayleigh scattering"}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
client = Client(httpserver.url_for('/'))
|
||||||
|
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json')
|
||||||
|
assert response['model'] == 'dummy'
|
||||||
|
assert response['message']['role'] == 'assistant'
|
||||||
|
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_chat_format_pydantic(httpserver: HTTPServer):
|
||||||
|
class ResponseFormat(BaseModel):
|
||||||
|
answer: str
|
||||||
|
confidence: float
|
||||||
|
|
||||||
|
httpserver.expect_ordered_request(
|
||||||
|
'/api/chat',
|
||||||
|
method='POST',
|
||||||
|
json={
|
||||||
|
'model': 'dummy',
|
||||||
|
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||||
|
'tools': [],
|
||||||
|
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||||
|
'stream': False,
|
||||||
|
},
|
||||||
|
).respond_with_json(
|
||||||
|
{
|
||||||
|
'model': 'dummy',
|
||||||
|
'message': {
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
client = Client(httpserver.url_for('/'))
|
||||||
|
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format=ResponseFormat.model_json_schema())
|
||||||
|
assert response['model'] == 'dummy'
|
||||||
|
assert response['message']['role'] == 'assistant'
|
||||||
|
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_client_chat_format_json(httpserver: HTTPServer):
|
||||||
|
httpserver.expect_ordered_request(
|
||||||
|
'/api/chat',
|
||||||
|
method='POST',
|
||||||
|
json={
|
||||||
|
'model': 'dummy',
|
||||||
|
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||||
|
'tools': [],
|
||||||
|
'format': 'json',
|
||||||
|
'stream': False,
|
||||||
|
},
|
||||||
|
).respond_with_json(
|
||||||
|
{
|
||||||
|
'model': 'dummy',
|
||||||
|
'message': {
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': '{"answer": "Because of Rayleigh scattering"}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
client = AsyncClient(httpserver.url_for('/'))
|
||||||
|
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json')
|
||||||
|
assert response['model'] == 'dummy'
|
||||||
|
assert response['message']['role'] == 'assistant'
|
||||||
|
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_client_chat_format_pydantic(httpserver: HTTPServer):
|
||||||
|
class ResponseFormat(BaseModel):
|
||||||
|
answer: str
|
||||||
|
confidence: float
|
||||||
|
|
||||||
|
httpserver.expect_ordered_request(
|
||||||
|
'/api/chat',
|
||||||
|
method='POST',
|
||||||
|
json={
|
||||||
|
'model': 'dummy',
|
||||||
|
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||||
|
'tools': [],
|
||||||
|
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||||
|
'stream': False,
|
||||||
|
},
|
||||||
|
).respond_with_json(
|
||||||
|
{
|
||||||
|
'model': 'dummy',
|
||||||
|
'message': {
|
||||||
|
'role': 'assistant',
|
||||||
|
'content': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
client = AsyncClient(httpserver.url_for('/'))
|
||||||
|
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format=ResponseFormat.model_json_schema())
|
||||||
|
assert response['model'] == 'dummy'
|
||||||
|
assert response['message']['role'] == 'assistant'
|
||||||
|
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||||
|
|
||||||
|
|
||||||
def test_client_generate(httpserver: HTTPServer):
|
def test_client_generate(httpserver: HTTPServer):
|
||||||
httpserver.expect_ordered_request(
|
httpserver.expect_ordered_request(
|
||||||
'/api/generate',
|
'/api/generate',
|
||||||
@ -205,6 +327,108 @@ def test_client_generate_images(httpserver: HTTPServer):
|
|||||||
assert response['response'] == 'Because it is.'
|
assert response['response'] == 'Because it is.'
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_generate_format_json(httpserver: HTTPServer):
|
||||||
|
httpserver.expect_ordered_request(
|
||||||
|
'/api/generate',
|
||||||
|
method='POST',
|
||||||
|
json={
|
||||||
|
'model': 'dummy',
|
||||||
|
'prompt': 'Why is the sky blue?',
|
||||||
|
'format': 'json',
|
||||||
|
'stream': False,
|
||||||
|
},
|
||||||
|
).respond_with_json(
|
||||||
|
{
|
||||||
|
'model': 'dummy',
|
||||||
|
'response': '{"answer": "Because of Rayleigh scattering"}',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
client = Client(httpserver.url_for('/'))
|
||||||
|
response = client.generate('dummy', 'Why is the sky blue?', format='json')
|
||||||
|
assert response['model'] == 'dummy'
|
||||||
|
assert response['response'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_generate_format_pydantic(httpserver: HTTPServer):
|
||||||
|
class ResponseFormat(BaseModel):
|
||||||
|
answer: str
|
||||||
|
confidence: float
|
||||||
|
|
||||||
|
httpserver.expect_ordered_request(
|
||||||
|
'/api/generate',
|
||||||
|
method='POST',
|
||||||
|
json={
|
||||||
|
'model': 'dummy',
|
||||||
|
'prompt': 'Why is the sky blue?',
|
||||||
|
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||||
|
'stream': False,
|
||||||
|
},
|
||||||
|
).respond_with_json(
|
||||||
|
{
|
||||||
|
'model': 'dummy',
|
||||||
|
'response': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
client = Client(httpserver.url_for('/'))
|
||||||
|
response = client.generate('dummy', 'Why is the sky blue?', format=ResponseFormat.model_json_schema())
|
||||||
|
assert response['model'] == 'dummy'
|
||||||
|
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_client_generate_format_json(httpserver: HTTPServer):
|
||||||
|
httpserver.expect_ordered_request(
|
||||||
|
'/api/generate',
|
||||||
|
method='POST',
|
||||||
|
json={
|
||||||
|
'model': 'dummy',
|
||||||
|
'prompt': 'Why is the sky blue?',
|
||||||
|
'format': 'json',
|
||||||
|
'stream': False,
|
||||||
|
},
|
||||||
|
).respond_with_json(
|
||||||
|
{
|
||||||
|
'model': 'dummy',
|
||||||
|
'response': '{"answer": "Because of Rayleigh scattering"}',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
client = AsyncClient(httpserver.url_for('/'))
|
||||||
|
response = await client.generate('dummy', 'Why is the sky blue?', format='json')
|
||||||
|
assert response['model'] == 'dummy'
|
||||||
|
assert response['response'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_client_generate_format_pydantic(httpserver: HTTPServer):
|
||||||
|
class ResponseFormat(BaseModel):
|
||||||
|
answer: str
|
||||||
|
confidence: float
|
||||||
|
|
||||||
|
httpserver.expect_ordered_request(
|
||||||
|
'/api/generate',
|
||||||
|
method='POST',
|
||||||
|
json={
|
||||||
|
'model': 'dummy',
|
||||||
|
'prompt': 'Why is the sky blue?',
|
||||||
|
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||||
|
'stream': False,
|
||||||
|
},
|
||||||
|
).respond_with_json(
|
||||||
|
{
|
||||||
|
'model': 'dummy',
|
||||||
|
'response': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
client = AsyncClient(httpserver.url_for('/'))
|
||||||
|
response = await client.generate('dummy', 'Why is the sky blue?', format=ResponseFormat.model_json_schema())
|
||||||
|
assert response['model'] == 'dummy'
|
||||||
|
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||||
|
|
||||||
|
|
||||||
def test_client_pull(httpserver: HTTPServer):
|
def test_client_pull(httpserver: HTTPServer):
|
||||||
httpserver.expect_ordered_request(
|
httpserver.expect_ordered_request(
|
||||||
'/api/pull',
|
'/api/pull',
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user