mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-13 21:57:16 +08:00
client: add support for passing in Image type to generate (#408)
This commit is contained in:
parent
0561f42701
commit
8d0d0e483d
@ -190,7 +190,7 @@ class Client(BaseClient):
|
||||
stream: Literal[False] = False,
|
||||
raw: bool = False,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> GenerateResponse: ...
|
||||
@ -208,7 +208,7 @@ class Client(BaseClient):
|
||||
stream: Literal[True] = True,
|
||||
raw: bool = False,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Iterator[GenerateResponse]: ...
|
||||
@ -225,7 +225,7 @@ class Client(BaseClient):
|
||||
stream: bool = False,
|
||||
raw: Optional[bool] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Union[GenerateResponse, Iterator[GenerateResponse]]:
|
||||
@ -694,7 +694,7 @@ class AsyncClient(BaseClient):
|
||||
stream: Literal[False] = False,
|
||||
raw: bool = False,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> GenerateResponse: ...
|
||||
@ -712,7 +712,7 @@ class AsyncClient(BaseClient):
|
||||
stream: Literal[True] = True,
|
||||
raw: bool = False,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> AsyncIterator[GenerateResponse]: ...
|
||||
@ -729,7 +729,7 @@ class AsyncClient(BaseClient):
|
||||
stream: bool = False,
|
||||
raw: Optional[bool] = None,
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]:
|
||||
|
||||
@ -11,6 +11,7 @@ from pytest_httpserver import HTTPServer, URIPattern
|
||||
from werkzeug.wrappers import Request, Response
|
||||
|
||||
from ollama._client import CONNECTION_ERROR_MESSAGE, AsyncClient, Client, _copy_tools
|
||||
from ollama._types import Image
|
||||
|
||||
PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'
|
||||
PNG_BYTES = base64.b64decode(PNG_BASE64)
|
||||
@ -286,6 +287,46 @@ def test_client_generate(httpserver: HTTPServer):
|
||||
assert response['response'] == 'Because it is.'
|
||||
|
||||
|
||||
def test_client_generate_with_image_type(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'What is in this image?',
|
||||
'stream': False,
|
||||
'images': [PNG_BASE64],
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': 'A blue sky.',
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'What is in this image?', images=[Image(value=PNG_BASE64)])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == 'A blue sky.'
|
||||
|
||||
|
||||
def test_client_generate_with_invalid_image(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'What is in this image?',
|
||||
'stream': False,
|
||||
'images': ['invalid_base64'],
|
||||
},
|
||||
).respond_with_json({'error': 'Invalid image data'}, status=400)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
with pytest.raises(ValueError):
|
||||
client.generate('dummy', 'What is in this image?', images=[Image(value='invalid_base64')])
|
||||
|
||||
|
||||
def test_client_generate_stream(httpserver: HTTPServer):
|
||||
def stream_handler(_: Request):
|
||||
def generate():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user