mirror of
https://github.com/ollama/ollama-python.git
synced 2026-05-03 12:52:35 +00:00
Add image generation support (#616)
This commit is contained in:
@@ -568,6 +568,115 @@ async def test_async_client_generate_format_pydantic(httpserver: HTTPServer):
|
||||
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||
|
||||
|
||||
def test_client_generate_image(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy-image',
|
||||
'prompt': 'a sunset over mountains',
|
||||
'stream': False,
|
||||
'width': 1024,
|
||||
'height': 768,
|
||||
'steps': 20,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy-image',
|
||||
'image': PNG_BASE64,
|
||||
'done': True,
|
||||
'done_reason': 'stop',
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy-image', 'a sunset over mountains', width=1024, height=768, steps=20)
|
||||
assert response['model'] == 'dummy-image'
|
||||
assert response['image'] == PNG_BASE64
|
||||
assert response['done'] is True
|
||||
|
||||
|
||||
def test_client_generate_image_stream(httpserver: HTTPServer):
|
||||
def stream_handler(_: Request):
|
||||
def generate():
|
||||
# Progress updates
|
||||
for i in range(1, 4):
|
||||
yield (
|
||||
json.dumps(
|
||||
{
|
||||
'model': 'dummy-image',
|
||||
'completed': i,
|
||||
'total': 3,
|
||||
'done': False,
|
||||
}
|
||||
)
|
||||
+ '\n'
|
||||
)
|
||||
# Final response with image
|
||||
yield (
|
||||
json.dumps(
|
||||
{
|
||||
'model': 'dummy-image',
|
||||
'image': PNG_BASE64,
|
||||
'done': True,
|
||||
'done_reason': 'stop',
|
||||
}
|
||||
)
|
||||
+ '\n'
|
||||
)
|
||||
|
||||
return Response(generate())
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy-image',
|
||||
'prompt': 'a sunset over mountains',
|
||||
'stream': True,
|
||||
'width': 512,
|
||||
'height': 512,
|
||||
},
|
||||
).respond_with_handler(stream_handler)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy-image', 'a sunset over mountains', stream=True, width=512, height=512)
|
||||
|
||||
parts = list(response)
|
||||
# Check progress updates
|
||||
assert parts[0]['completed'] == 1
|
||||
assert parts[0]['total'] == 3
|
||||
assert parts[0]['done'] is False
|
||||
# Check final response
|
||||
assert parts[-1]['image'] == PNG_BASE64
|
||||
assert parts[-1]['done'] is True
|
||||
|
||||
|
||||
async def test_async_client_generate_image(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy-image',
|
||||
'prompt': 'a robot painting',
|
||||
'stream': False,
|
||||
'width': 1024,
|
||||
'height': 1024,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy-image',
|
||||
'image': PNG_BASE64,
|
||||
'done': True,
|
||||
}
|
||||
)
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.generate('dummy-image', 'a robot painting', width=1024, height=1024)
|
||||
assert response['model'] == 'dummy-image'
|
||||
assert response['image'] == PNG_BASE64
|
||||
|
||||
|
||||
def test_client_pull(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/pull',
|
||||
|
||||
Reference in New Issue
Block a user