client: add support for passing in Image type to generate (#408)

This commit is contained in:
Parth Sareen 2025-02-14 09:44:43 -08:00 committed by GitHub
parent 0561f42701
commit 8d0d0e483d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 6 deletions

View File

@ -190,7 +190,7 @@ class Client(BaseClient):
stream: Literal[False] = False,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> GenerateResponse: ...
@ -208,7 +208,7 @@ class Client(BaseClient):
stream: Literal[True] = True,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Iterator[GenerateResponse]: ...
@ -225,7 +225,7 @@ class Client(BaseClient):
stream: bool = False,
raw: Optional[bool] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[GenerateResponse, Iterator[GenerateResponse]]:
@ -694,7 +694,7 @@ class AsyncClient(BaseClient):
stream: Literal[False] = False,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> GenerateResponse: ...
@ -712,7 +712,7 @@ class AsyncClient(BaseClient):
stream: Literal[True] = True,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> AsyncIterator[GenerateResponse]: ...
@ -729,7 +729,7 @@ class AsyncClient(BaseClient):
stream: bool = False,
raw: Optional[bool] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]:

View File

@ -11,6 +11,7 @@ from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Request, Response
from ollama._client import CONNECTION_ERROR_MESSAGE, AsyncClient, Client, _copy_tools
from ollama._types import Image
PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'
PNG_BYTES = base64.b64decode(PNG_BASE64)
@ -286,6 +287,46 @@ def test_client_generate(httpserver: HTTPServer):
assert response['response'] == 'Because it is.'
def test_client_generate_with_image_type(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'What is in this image?',
'stream': False,
'images': [PNG_BASE64],
},
).respond_with_json(
{
'model': 'dummy',
'response': 'A blue sky.',
}
)
client = Client(httpserver.url_for('/'))
response = client.generate('dummy', 'What is in this image?', images=[Image(value=PNG_BASE64)])
assert response['model'] == 'dummy'
assert response['response'] == 'A blue sky.'
def test_client_generate_with_invalid_image(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'What is in this image?',
'stream': False,
'images': ['invalid_base64'],
},
).respond_with_json({'error': 'Invalid image data'}, status=400)
client = Client(httpserver.url_for('/'))
with pytest.raises(ValueError):
client.generate('dummy', 'What is in this image?', images=[Image(value='invalid_base64')])
def test_client_generate_stream(httpserver: HTTPServer):
def stream_handler(_: Request):
def generate():