bugfix: fix passing Image type in messages for chat (#390)

---------

Co-authored-by: Aarni Koskela <akx@iki.fi>
This commit is contained in:
Parth Sareen
2024-12-29 14:43:07 -08:00
committed by GitHub
parent 7d1e002be9
commit ee349ecc6d
4 changed files with 49 additions and 124 deletions
+39 -21
View File
@@ -1,5 +1,5 @@
import base64
import os
import io
import json
from pydantic import ValidationError, BaseModel
import pytest
@@ -7,10 +7,12 @@ import tempfile
from pathlib import Path
from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Request, Response
from PIL import Image
from ollama._client import Client, AsyncClient, _copy_tools
PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'
PNG_BYTES = base64.b64decode(PNG_BASE64)
class PrefixPattern(URIPattern):
def __init__(self, prefix: str):
@@ -86,7 +88,11 @@ def test_client_chat_stream(httpserver: HTTPServer):
assert part['message']['content'] == next(it)
def test_client_chat_images(httpserver: HTTPServer):
@pytest.mark.parametrize('message_format', ('dict', 'pydantic_model'))
@pytest.mark.parametrize('file_style', ('path', 'bytes'))
def test_client_chat_images(httpserver: HTTPServer, message_format: str, file_style: str, tmp_path):
from ollama._types import Message, Image
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
@@ -96,7 +102,7 @@ def test_client_chat_images(httpserver: HTTPServer):
{
'role': 'user',
'content': 'Why is the sky blue?',
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'images': [PNG_BASE64],
},
],
'tools': [],
@@ -114,12 +120,24 @@ def test_client_chat_images(httpserver: HTTPServer):
client = Client(httpserver.url_for('/'))
with io.BytesIO() as b:
Image.new('RGB', (1, 1)).save(b, 'PNG')
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == "I don't know."
if file_style == 'bytes':
image_content = PNG_BYTES
elif file_style == 'path':
image_path = tmp_path / 'transparent.png'
image_path.write_bytes(PNG_BYTES)
image_content = str(image_path)
if message_format == 'pydantic_model':
messages = [Message(role='user', content='Why is the sky blue?', images=[Image(value=image_content)])]
elif message_format == 'dict':
messages = [{'role': 'user', 'content': 'Why is the sky blue?', 'images': [image_content]}]
else:
raise ValueError(f'Invalid message format: {message_format}')
response = client.chat('dummy', messages=messages)
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == "I don't know."
def test_client_chat_format_json(httpserver: HTTPServer):
@@ -309,7 +327,7 @@ def test_client_generate_images(httpserver: HTTPServer):
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'stream': False,
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'images': [PNG_BASE64],
},
).respond_with_json(
{
@@ -321,7 +339,8 @@ def test_client_generate_images(httpserver: HTTPServer):
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as temp:
Image.new('RGB', (1, 1)).save(temp, 'PNG')
temp.write(PNG_BYTES)
temp.flush()
response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
assert response['model'] == 'dummy'
assert response['response'] == 'Because it is.'
@@ -792,7 +811,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
{
'role': 'user',
'content': 'Why is the sky blue?',
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'images': [PNG_BASE64],
},
],
'tools': [],
@@ -810,12 +829,10 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
client = AsyncClient(httpserver.url_for('/'))
with io.BytesIO() as b:
Image.new('RGB', (1, 1)).save(b, 'PNG')
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == "I don't know."
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [PNG_BYTES]}])
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == "I don't know."
@pytest.mark.asyncio
@@ -886,7 +903,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'stream': False,
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'images': [PNG_BASE64],
},
).respond_with_json(
{
@@ -898,7 +915,8 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as temp:
Image.new('RGB', (1, 1)).save(temp, 'PNG')
temp.write(PNG_BYTES)
temp.flush()
response = await client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
assert response['model'] == 'dummy'
assert response['response'] == 'Because it is.'