From 8d0d0e483d050c7d2a23aece922c6ca4ef5e1779 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Fri, 14 Feb 2025 09:44:43 -0800 Subject: [PATCH] client: add support for passing in Image type to generate (#408) --- ollama/_client.py | 12 ++++++------ tests/test_client.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/ollama/_client.py b/ollama/_client.py index cbe43c9..541d9c8 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -190,7 +190,7 @@ class Client(BaseClient): stream: Literal[False] = False, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> GenerateResponse: ... @@ -208,7 +208,7 @@ class Client(BaseClient): stream: Literal[True] = True, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> Iterator[GenerateResponse]: ... @@ -225,7 +225,7 @@ class Client(BaseClient): stream: bool = False, raw: Optional[bool] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> Union[GenerateResponse, Iterator[GenerateResponse]]: @@ -694,7 +694,7 @@ class AsyncClient(BaseClient): stream: Literal[False] = False, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> GenerateResponse: ... @@ -712,7 +712,7 @@ class AsyncClient(BaseClient): stream: Literal[True] = True, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> AsyncIterator[GenerateResponse]: ... @@ -729,7 +729,7 @@ class AsyncClient(BaseClient): stream: bool = False, raw: Optional[bool] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, - images: Optional[Sequence[Union[str, bytes]]] = None, + images: Optional[Sequence[Union[str, bytes, Image]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]: diff --git a/tests/test_client.py b/tests/test_client.py index dacb953..8890afd 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,6 +11,7 @@ from pytest_httpserver import HTTPServer, URIPattern from werkzeug.wrappers import Request, Response from ollama._client import CONNECTION_ERROR_MESSAGE, AsyncClient, Client, _copy_tools +from ollama._types import Image PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC' PNG_BYTES = base64.b64decode(PNG_BASE64) @@ -286,6 +287,46 @@ def test_client_generate(httpserver: HTTPServer): assert response['response'] == 'Because it is.' +def test_client_generate_with_image_type(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/generate', + method='POST', + json={ + 'model': 'dummy', + 'prompt': 'What is in this image?', + 'stream': False, + 'images': [PNG_BASE64], + }, + ).respond_with_json( + { + 'model': 'dummy', + 'response': 'A blue sky.', + } + ) + + client = Client(httpserver.url_for('/')) + response = client.generate('dummy', 'What is in this image?', images=[Image(value=PNG_BASE64)]) + assert response['model'] == 'dummy' + assert response['response'] == 'A blue sky.' + + +def test_client_generate_with_invalid_image(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/generate', + method='POST', + json={ + 'model': 'dummy', + 'prompt': 'What is in this image?', + 'stream': False, + 'images': ['invalid_base64'], + }, + ).respond_with_json({'error': 'Invalid image data'}, status=400) + + client = Client(httpserver.url_for('/')) + with pytest.raises(ValueError): + client.generate('dummy', 'What is in this image?', images=[Image(value='invalid_base64')]) + + def test_client_generate_stream(httpserver: HTTPServer): def stream_handler(_: Request): def generate():