mirror of
https://github.com/ollama/ollama-python.git
synced 2026-05-03 12:52:35 +00:00
Structured outputs support with examples (#354)
This commit is contained in:
+225
-1
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
from pydantic import ValidationError
|
||||
from pydantic import ValidationError, BaseModel
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
@@ -122,6 +122,128 @@ def test_client_chat_images(httpserver: HTTPServer):
|
||||
assert response['message']['content'] == "I don't know."
|
||||
|
||||
|
||||
def test_client_chat_format_json(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'format': 'json',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': '{"answer": "Because of Rayleigh scattering"}',
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json')
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||
|
||||
|
||||
def test_client_chat_format_pydantic(httpserver: HTTPServer):
|
||||
class ResponseFormat(BaseModel):
|
||||
answer: str
|
||||
confidence: float
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format=ResponseFormat.model_json_schema())
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_chat_format_json(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'format': 'json',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': '{"answer": "Because of Rayleigh scattering"}',
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json')
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_chat_format_pydantic(httpserver: HTTPServer):
|
||||
class ResponseFormat(BaseModel):
|
||||
answer: str
|
||||
confidence: float
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format=ResponseFormat.model_json_schema())
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||
|
||||
|
||||
def test_client_generate(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
@@ -205,6 +327,108 @@ def test_client_generate_images(httpserver: HTTPServer):
|
||||
assert response['response'] == 'Because it is.'
|
||||
|
||||
|
||||
def test_client_generate_format_json(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'format': 'json',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': '{"answer": "Because of Rayleigh scattering"}',
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?', format='json')
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||
|
||||
|
||||
def test_client_generate_format_pydantic(httpserver: HTTPServer):
|
||||
class ResponseFormat(BaseModel):
|
||||
answer: str
|
||||
confidence: float
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?', format=ResponseFormat.model_json_schema())
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_generate_format_json(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'format': 'json',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': '{"answer": "Because of Rayleigh scattering"}',
|
||||
}
|
||||
)
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.generate('dummy', 'Why is the sky blue?', format='json')
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == '{"answer": "Because of Rayleigh scattering"}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_generate_format_pydantic(httpserver: HTTPServer):
|
||||
class ResponseFormat(BaseModel):
|
||||
answer: str
|
||||
confidence: float
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
|
||||
}
|
||||
)
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.generate('dummy', 'Why is the sky blue?', format=ResponseFormat.model_json_schema())
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||
|
||||
|
||||
def test_client_pull(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/pull',
|
||||
|
||||
Reference in New Issue
Block a user