add keep_alive

Michael Yang 2024-01-26 11:19:52 -08:00
parent f618a2f448
commit fbb6553e03
2 changed files with 36 additions and 2 deletions

--- a/ollama/_client.py
+++ b/ollama/_client.py

@@ -92,6 +92,7 @@ class Client(BaseClient):
     format: Literal['', 'json'] = '',
     images: Optional[Sequence[AnyStr]] = None,
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
   ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
     """
     Create a response using the requested model.
@@ -120,6 +121,7 @@ class Client(BaseClient):
         'images': [_encode_image(image) for image in images or []],
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
@@ -131,6 +133,7 @@ class Client(BaseClient):
     stream: bool = False,
     format: Literal['', 'json'] = '',
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
   ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
     """
     Create a chat response using the requested model.
@@ -164,11 +167,18 @@ class Client(BaseClient):
         'stream': stream,
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
 
-  def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
+  def embeddings(
+    self,
+    model: str = '',
+    prompt: str = '',
+    options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
+  ) -> Sequence[float]:
     return self._request(
       'POST',
       '/api/embeddings',
@@ -176,6 +186,7 @@ class Client(BaseClient):
         'model': model,
         'prompt': prompt,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
     ).json()
 
@@ -360,6 +371,7 @@ class AsyncClient(BaseClient):
     format: Literal['', 'json'] = '',
     images: Optional[Sequence[AnyStr]] = None,
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
   ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
     """
     Create a response using the requested model.
@@ -387,6 +399,7 @@ class AsyncClient(BaseClient):
         'images': [_encode_image(image) for image in images or []],
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
@@ -398,6 +411,7 @@ class AsyncClient(BaseClient):
     stream: bool = False,
     format: Literal['', 'json'] = '',
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
   ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
     """
     Create a chat response using the requested model.
@@ -430,11 +444,18 @@ class AsyncClient(BaseClient):
         'stream': stream,
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
 
-  async def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
+  async def embeddings(
+    self,
+    model: str = '',
+    prompt: str = '',
+    options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
+  ) -> Sequence[float]:
     response = await self._request(
       'POST',
       '/api/embeddings',
@@ -442,6 +463,7 @@ class AsyncClient(BaseClient):
         'model': model,
         'prompt': prompt,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
     )
     return response.json()
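
For reference, a minimal usage sketch of the new parameter (a sketch only: the model name and a locally running Ollama server are assumptions, not part of this commit). Per the Ollama API, keep_alive accepts a number of seconds or a duration string such as '5m' and controls how long the model stays loaded after the request; None defers to the server default.

# Sketch, not part of this commit: assumes a local Ollama server and a pulled model.
from ollama import Client

client = Client()

# Keep the model loaded for five minutes after this call.
client.generate(model='llama2', prompt='Why is the sky blue?', keep_alive='5m')

# 0 asks the server to unload the model immediately after the request.
client.embeddings(model='llama2', prompt='Why is the sky blue?', keep_alive=0)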

--- a/tests/test_client.py
+++ b/tests/test_client.py

@@ -29,6 +29,7 @@ def test_client_chat(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -75,6 +76,7 @@ def test_client_chat_stream(httpserver: HTTPServer):
       'stream': True,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -103,6 +105,7 @@ def test_client_chat_images(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -139,6 +142,7 @@ def test_client_generate(httpserver: HTTPServer):
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -183,6 +187,7 @@ def test_client_generate_stream(httpserver: HTTPServer):
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -210,6 +215,7 @@ def test_client_generate_images(httpserver: HTTPServer):
       'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -465,6 +471,7 @@ async def test_async_client_chat(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
 
@@ -502,6 +509,7 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
       'stream': True,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -531,6 +539,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
 
@@ -558,6 +567,7 @@ async def test_async_client_generate(httpserver: HTTPServer):
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
 
@@ -597,6 +607,7 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -625,6 +636,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
       'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
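
The async client takes the same parameter on chat, generate, and embeddings; a corresponding sketch (same assumptions as above; the negative value meaning "keep the model loaded indefinitely" follows the Ollama API's keep_alive semantics):

# Sketch, not part of this commit.
import asyncio

from ollama import AsyncClient


async def main():
  client = AsyncClient()
  # A negative keep_alive keeps the model loaded until it is explicitly unloaded.
  response = await client.chat(
    model='llama2',
    messages=[{'role': 'user', 'content': 'Why is the sky blue?'}],
    keep_alive=-1,
  )
  print(response['message']['content'])


asyncio.run(main())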