mirror of
https://github.com/ollama/ollama-python.git
synced 2026-05-03 12:52:35 +00:00
bugfix: fix passing Image type in messages for chat (#390)
--------- Co-authored-by: Aarni Koskela <akx@iki.fi>
This commit is contained in:
+39
-21
@@ -1,5 +1,5 @@
|
||||
import base64
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
from pydantic import ValidationError, BaseModel
|
||||
import pytest
|
||||
@@ -7,10 +7,12 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from pytest_httpserver import HTTPServer, URIPattern
|
||||
from werkzeug.wrappers import Request, Response
|
||||
from PIL import Image
|
||||
|
||||
from ollama._client import Client, AsyncClient, _copy_tools
|
||||
|
||||
PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'
|
||||
PNG_BYTES = base64.b64decode(PNG_BASE64)
|
||||
|
||||
|
||||
class PrefixPattern(URIPattern):
|
||||
def __init__(self, prefix: str):
|
||||
@@ -86,7 +88,11 @@ def test_client_chat_stream(httpserver: HTTPServer):
|
||||
assert part['message']['content'] == next(it)
|
||||
|
||||
|
||||
def test_client_chat_images(httpserver: HTTPServer):
|
||||
@pytest.mark.parametrize('message_format', ('dict', 'pydantic_model'))
|
||||
@pytest.mark.parametrize('file_style', ('path', 'bytes'))
|
||||
def test_client_chat_images(httpserver: HTTPServer, message_format: str, file_style: str, tmp_path):
|
||||
from ollama._types import Message, Image
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
@@ -96,7 +102,7 @@ def test_client_chat_images(httpserver: HTTPServer):
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'images': [PNG_BASE64],
|
||||
},
|
||||
],
|
||||
'tools': [],
|
||||
@@ -114,12 +120,24 @@ def test_client_chat_images(httpserver: HTTPServer):
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with io.BytesIO() as b:
|
||||
Image.new('RGB', (1, 1)).save(b, 'PNG')
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == "I don't know."
|
||||
if file_style == 'bytes':
|
||||
image_content = PNG_BYTES
|
||||
elif file_style == 'path':
|
||||
image_path = tmp_path / 'transparent.png'
|
||||
image_path.write_bytes(PNG_BYTES)
|
||||
image_content = str(image_path)
|
||||
|
||||
if message_format == 'pydantic_model':
|
||||
messages = [Message(role='user', content='Why is the sky blue?', images=[Image(value=image_content)])]
|
||||
elif message_format == 'dict':
|
||||
messages = [{'role': 'user', 'content': 'Why is the sky blue?', 'images': [image_content]}]
|
||||
else:
|
||||
raise ValueError(f'Invalid message format: {message_format}')
|
||||
|
||||
response = client.chat('dummy', messages=messages)
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == "I don't know."
|
||||
|
||||
|
||||
def test_client_chat_format_json(httpserver: HTTPServer):
|
||||
@@ -309,7 +327,7 @@ def test_client_generate_images(httpserver: HTTPServer):
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'stream': False,
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'images': [PNG_BASE64],
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
@@ -321,7 +339,8 @@ def test_client_generate_images(httpserver: HTTPServer):
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
Image.new('RGB', (1, 1)).save(temp, 'PNG')
|
||||
temp.write(PNG_BYTES)
|
||||
temp.flush()
|
||||
response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == 'Because it is.'
|
||||
@@ -792,7 +811,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'images': [PNG_BASE64],
|
||||
},
|
||||
],
|
||||
'tools': [],
|
||||
@@ -810,12 +829,10 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with io.BytesIO() as b:
|
||||
Image.new('RGB', (1, 1)).save(b, 'PNG')
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == "I don't know."
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [PNG_BYTES]}])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == "I don't know."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -886,7 +903,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'stream': False,
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'images': [PNG_BASE64],
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
@@ -898,7 +915,8 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
Image.new('RGB', (1, 1)).save(temp, 'PNG')
|
||||
temp.write(PNG_BYTES)
|
||||
temp.flush()
|
||||
response = await client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == 'Because it is.'
|
||||
|
||||
Reference in New Issue
Block a user