diff --git a/ollama/_client.py b/ollama/_client.py
index 001162d..c5aa1bc 100644
--- a/ollama/_client.py
+++ b/ollama/_client.py
@@ -93,6 +93,7 @@ class Client(BaseClient):
     format: Literal['', 'json'] = '',
     images: Optional[Sequence[AnyStr]] = None,
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
   ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
     """
     Create a response using the requested model.
@@ -121,6 +122,7 @@ class Client(BaseClient):
         'images': [_encode_image(image) for image in images or []],
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
@@ -132,6 +134,7 @@ class Client(BaseClient):
     stream: bool = False,
     format: Literal['', 'json'] = '',
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
   ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
     """
     Create a chat response using the requested model.
@@ -165,11 +168,18 @@ class Client(BaseClient):
         'stream': stream,
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
 
-  def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
+  def embeddings(
+    self,
+    model: str = '',
+    prompt: str = '',
+    options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
+  ) -> Sequence[float]:
     return self._request(
       'POST',
       '/api/embeddings',
@@ -177,6 +187,7 @@ class Client(BaseClient):
         'model': model,
         'prompt': prompt,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
     ).json()
 
@@ -364,6 +375,7 @@ class AsyncClient(BaseClient):
     format: Literal['', 'json'] = '',
     images: Optional[Sequence[AnyStr]] = None,
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
   ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
     """
     Create a response using the requested model.
@@ -391,6 +403,7 @@ class AsyncClient(BaseClient):
         'images': [_encode_image(image) for image in images or []],
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
@@ -402,6 +415,7 @@ class AsyncClient(BaseClient):
     stream: bool = False,
     format: Literal['', 'json'] = '',
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
   ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
     """
     Create a chat response using the requested model.
@@ -434,11 +448,18 @@ class AsyncClient(BaseClient):
         'stream': stream,
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
 
-  async def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
+  async def embeddings(
+    self,
+    model: str = '',
+    prompt: str = '',
+    options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
+  ) -> Sequence[float]:
     response = await self._request(
       'POST',
       '/api/embeddings',
@@ -446,6 +467,7 @@ class AsyncClient(BaseClient):
         'model': model,
         'prompt': prompt,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
     )
     return response.json()
diff --git a/tests/test_client.py b/tests/test_client.py
index 08aa789..158902f 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -29,6 +29,7 @@ def test_client_chat(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -75,6 +76,7 @@ def test_client_chat_stream(httpserver: HTTPServer):
       'stream': True,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -103,6 +105,7 @@ def test_client_chat_images(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -139,6 +142,7 @@ def test_client_generate(httpserver: HTTPServer):
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -183,6 +187,7 @@ def test_client_generate_stream(httpserver: HTTPServer):
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -210,6 +215,7 @@ def test_client_generate_images(httpserver: HTTPServer):
       'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -513,6 +519,7 @@ async def test_async_client_chat(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
 
@@ -550,6 +557,7 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
       'stream': True,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -579,6 +587,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
 
@@ -606,6 +615,7 @@ async def test_async_client_generate(httpserver: HTTPServer):
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
 
@@ -645,6 +655,7 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -673,6 +684,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
       'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
 
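Usage sketch (not part of the diff; the model name is illustrative). Per the Ollama REST API, keep_alive controls how long the model stays loaded in memory after a request: it accepts a duration string such as '5m' or a number interpreted as seconds, where 0 unloads the model as soon as the call returns and a negative value keeps it loaded indefinitely.

  import ollama

  client = ollama.Client()

  # Duration string: keep the model loaded for five minutes after this call.
  client.generate(model='llama2', prompt='Why is the sky blue?', keep_alive='5m')

  # Numeric value, in seconds: 0 unloads the model immediately after the call.
  client.embeddings(model='llama2', prompt='sky', keep_alive=0)

When keep_alive is left at its default of None, the client sends null and the server falls back to its own default retention (five minutes as of this writing), which is why the tests above expect 'keep_alive': None in the request body.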