mirror of
https://github.com/ollama/ollama-python.git
synced 2026-06-16 21:24:52 +00:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fcdf5771f5 | |||
| ec8bf88c2b | |||
| 8b929ab496 | |||
| eee32dda37 | |||
| c74dd5835d | |||
| cdec2ad99e | |||
| 4a81fa43ee | |||
| 98ad0d884e | |||
| c27eebc158 | |||
| 46291d49a7 | |||
| cf3ab807c8 | |||
| 8e5d431d0d | |||
| e201181d4c | |||
| c077b5d685 | |||
| fbb6553e03 | |||
| f618a2f448 | |||
| 354f012168 |
@@ -17,7 +17,8 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: poetry
|
||||
- run: poetry install --with=dev
|
||||
- run: poetry run ruff --output-format=github .
|
||||
- run: poetry run ruff check --output-format=github .
|
||||
- run: poetry run ruff format --check .
|
||||
- run: poetry run pytest . --junitxml=junit/test-results-${{ matrix.python-version }}.xml --cov=ollama --cov-report=xml --cov-report=html
|
||||
- name: check poetry.lock is up-to-date
|
||||
run: poetry check --lock
|
||||
|
||||
@@ -138,7 +138,7 @@ async def chat():
|
||||
asyncio.run(chat())
|
||||
```
|
||||
|
||||
Setting `stream=True`` modifies functions to return a Python asynchronous generator:
|
||||
Setting `stream=True` modifies functions to return a Python asynchronous generator:
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
@@ -162,7 +162,7 @@ model = 'does-not-yet-exist'
|
||||
try:
|
||||
ollama.chat(model)
|
||||
except ollama.ResponseError as e:
|
||||
print('Error:', e.content)
|
||||
print('Error:', e.error)
|
||||
if e.status_code == 404:
|
||||
ollama.pull(model)
|
||||
```
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
import sys
|
||||
|
||||
from ollama import create
|
||||
|
||||
|
||||
args = sys.argv[1:]
|
||||
if len(args) == 2:
|
||||
# create from local file
|
||||
path = args[1]
|
||||
else:
|
||||
print('usage: python main.py <name> <filepath>')
|
||||
sys.exit(1)
|
||||
|
||||
# TODO: update to real Modelfile values
|
||||
modelfile = f"""
|
||||
FROM {path}
|
||||
"""
|
||||
|
||||
for response in create(model=args[0], modelfile=modelfile, stream=True):
|
||||
print(response['status'])
|
||||
+91
-23
@@ -2,11 +2,13 @@ import os
|
||||
import io
|
||||
import json
|
||||
import httpx
|
||||
import binascii
|
||||
import platform
|
||||
import urllib.parse
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from hashlib import sha256
|
||||
from base64 import b64encode
|
||||
from base64 import b64encode, b64decode
|
||||
|
||||
from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal
|
||||
|
||||
@@ -17,6 +19,13 @@ if sys.version_info < (3, 9):
|
||||
else:
|
||||
from collections.abc import Iterator, AsyncIterator
|
||||
|
||||
from importlib import metadata
|
||||
|
||||
try:
|
||||
__version__ = metadata.version('ollama')
|
||||
except metadata.PackageNotFoundError:
|
||||
__version__ = '0.0.0'
|
||||
|
||||
from ollama._types import Message, Options, RequestError, ResponseError
|
||||
|
||||
|
||||
@@ -36,10 +45,17 @@ class BaseClient:
|
||||
- `timeout`: None
|
||||
`kwargs` are passed to the httpx client.
|
||||
"""
|
||||
|
||||
headers = kwargs.pop('headers', {})
|
||||
headers['Content-Type'] = 'application/json'
|
||||
headers['Accept'] = 'application/json'
|
||||
headers['User-Agent'] = f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}'
|
||||
|
||||
self._client = client(
|
||||
base_url=_parse_host(host or os.getenv('OLLAMA_HOST')),
|
||||
follow_redirects=follow_redirects,
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -92,6 +108,7 @@ class Client(BaseClient):
|
||||
format: Literal['', 'json'] = '',
|
||||
images: Optional[Sequence[AnyStr]] = None,
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Create a response using the requested model.
|
||||
@@ -120,6 +137,7 @@ class Client(BaseClient):
|
||||
'images': [_encode_image(image) for image in images or []],
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
'keep_alive': keep_alive,
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
@@ -131,6 +149,7 @@ class Client(BaseClient):
|
||||
stream: bool = False,
|
||||
format: Literal['', 'json'] = '',
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Create a chat response using the requested model.
|
||||
@@ -164,11 +183,18 @@ class Client(BaseClient):
|
||||
'stream': stream,
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
'keep_alive': keep_alive,
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
|
||||
def embeddings(
|
||||
self,
|
||||
model: str = '',
|
||||
prompt: str = '',
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Sequence[float]:
|
||||
return self._request(
|
||||
'POST',
|
||||
'/api/embeddings',
|
||||
@@ -176,6 +202,7 @@ class Client(BaseClient):
|
||||
'model': model,
|
||||
'prompt': prompt,
|
||||
'options': options or {},
|
||||
'keep_alive': keep_alive,
|
||||
},
|
||||
).json()
|
||||
|
||||
@@ -259,13 +286,16 @@ class Client(BaseClient):
|
||||
out = io.StringIO()
|
||||
for line in io.StringIO(modelfile):
|
||||
command, _, args = line.partition(' ')
|
||||
if command.upper() in ['FROM', 'ADAPTER']:
|
||||
path = Path(args.strip()).expanduser()
|
||||
path = path if path.is_absolute() else base / path
|
||||
if path.exists():
|
||||
args = f'@{self._create_blob(path)}'
|
||||
if command.upper() not in ['FROM', 'ADAPTER']:
|
||||
print(line, end='', file=out)
|
||||
continue
|
||||
|
||||
path = Path(args.strip()).expanduser()
|
||||
path = path if path.is_absolute() else base / path
|
||||
if path.exists():
|
||||
args = f'@{self._create_blob(path)}\n'
|
||||
print(command, args, end='', file=out)
|
||||
|
||||
print(command, args, file=out)
|
||||
return out.getvalue()
|
||||
|
||||
def _create_blob(self, path: Union[str, Path]) -> str:
|
||||
@@ -360,6 +390,7 @@ class AsyncClient(BaseClient):
|
||||
format: Literal['', 'json'] = '',
|
||||
images: Optional[Sequence[AnyStr]] = None,
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Create a response using the requested model.
|
||||
@@ -387,6 +418,7 @@ class AsyncClient(BaseClient):
|
||||
'images': [_encode_image(image) for image in images or []],
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
'keep_alive': keep_alive,
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
@@ -398,6 +430,7 @@ class AsyncClient(BaseClient):
|
||||
stream: bool = False,
|
||||
format: Literal['', 'json'] = '',
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
"""
|
||||
Create a chat response using the requested model.
|
||||
@@ -430,11 +463,18 @@ class AsyncClient(BaseClient):
|
||||
'stream': stream,
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
'keep_alive': keep_alive,
|
||||
},
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
async def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str = '',
|
||||
prompt: str = '',
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Sequence[float]:
|
||||
response = await self._request(
|
||||
'POST',
|
||||
'/api/embeddings',
|
||||
@@ -442,6 +482,7 @@ class AsyncClient(BaseClient):
|
||||
'model': model,
|
||||
'prompt': prompt,
|
||||
'options': options or {},
|
||||
'keep_alive': keep_alive,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -527,13 +568,16 @@ class AsyncClient(BaseClient):
|
||||
out = io.StringIO()
|
||||
for line in io.StringIO(modelfile):
|
||||
command, _, args = line.partition(' ')
|
||||
if command.upper() in ['FROM', 'ADAPTER']:
|
||||
path = Path(args).expanduser()
|
||||
path = path if path.is_absolute() else base / path
|
||||
if path.exists():
|
||||
args = f'@{await self._create_blob(path)}'
|
||||
if command.upper() not in ['FROM', 'ADAPTER']:
|
||||
print(line, end='', file=out)
|
||||
continue
|
||||
|
||||
path = Path(args.strip()).expanduser()
|
||||
path = path if path.is_absolute() else base / path
|
||||
if path.exists():
|
||||
args = f'@{await self._create_blob(path)}\n'
|
||||
print(command, args, end='', file=out)
|
||||
|
||||
print(command, args, file=out)
|
||||
return out.getvalue()
|
||||
|
||||
async def _create_blob(self, path: Union[str, Path]) -> str:
|
||||
@@ -583,19 +627,43 @@ class AsyncClient(BaseClient):
|
||||
|
||||
|
||||
def _encode_image(image) -> str:
|
||||
if p := _as_path(image):
|
||||
b64 = b64encode(p.read_bytes())
|
||||
elif b := _as_bytesio(image):
|
||||
b64 = b64encode(b.read())
|
||||
else:
|
||||
raise RequestError('images must be a list of bytes, path-like objects, or file-like objects')
|
||||
"""
|
||||
>>> _encode_image(b'ollama')
|
||||
'b2xsYW1h'
|
||||
>>> _encode_image(io.BytesIO(b'ollama'))
|
||||
'b2xsYW1h'
|
||||
>>> _encode_image('LICENSE')
|
||||
'TUlUIExpY2Vuc2UKCkNvcHlyaWdodCAoYykgT2xsYW1hCgpQZXJtaXNzaW9uIGlzIGhlcmVieSBncmFudGVkLCBmcmVlIG9mIGNoYXJnZSwgdG8gYW55IHBlcnNvbiBvYnRhaW5pbmcgYSBjb3B5Cm9mIHRoaXMgc29mdHdhcmUgYW5kIGFzc29jaWF0ZWQgZG9jdW1lbnRhdGlvbiBmaWxlcyAodGhlICJTb2Z0d2FyZSIpLCB0byBkZWFsCmluIHRoZSBTb2Z0d2FyZSB3aXRob3V0IHJlc3RyaWN0aW9uLCBpbmNsdWRpbmcgd2l0aG91dCBsaW1pdGF0aW9uIHRoZSByaWdodHMKdG8gdXNlLCBjb3B5LCBtb2RpZnksIG1lcmdlLCBwdWJsaXNoLCBkaXN0cmlidXRlLCBzdWJsaWNlbnNlLCBhbmQvb3Igc2VsbApjb3BpZXMgb2YgdGhlIFNvZnR3YXJlLCBhbmQgdG8gcGVybWl0IHBlcnNvbnMgdG8gd2hvbSB0aGUgU29mdHdhcmUgaXMKZnVybmlzaGVkIHRvIGRvIHNvLCBzdWJqZWN0IHRvIHRoZSBmb2xsb3dpbmcgY29uZGl0aW9uczoKClRoZSBhYm92ZSBjb3B5cmlnaHQgbm90aWNlIGFuZCB0aGlzIHBlcm1pc3Npb24gbm90aWNlIHNoYWxsIGJlIGluY2x1ZGVkIGluIGFsbApjb3BpZXMgb3Igc3Vic3RhbnRpYWwgcG9ydGlvbnMgb2YgdGhlIFNvZnR3YXJlLgoKVEhFIFNPRlRXQVJFIElTIFBST1ZJREVEICJBUyBJUyIsIFdJVEhPVVQgV0FSUkFOVFkgT0YgQU5ZIEtJTkQsIEVYUFJFU1MgT1IKSU1QTElFRCwgSU5DTFVESU5HIEJVVCBOT1QgTElNSVRFRCBUTyBUSEUgV0FSUkFOVElFUyBPRiBNRVJDSEFOVEFCSUxJVFksCkZJVE5FU1MgRk9SIEEgUEFSVElDVUxBUiBQVVJQT1NFIEFORCBOT05JTkZSSU5HRU1FTlQuIElOIE5PIEVWRU5UIFNIQUxMIFRIRQpBVVRIT1JTIE9SIENPUFlSSUdIVCBIT0xERVJTIEJFIExJQUJMRSBGT1IgQU5ZIENMQUlNLCBEQU1BR0VTIE9SIE9USEVSCkxJQUJJTElUWSwgV0hFVEhFUiBJTiBBTiBBQ1RJT04gT0YgQ09OVFJBQ1QsIFRPUlQgT1IgT1RIRVJXSVNFLCBBUklTSU5HIEZST00sCk9VVCBPRiBPUiBJTiBDT05ORUNUSU9OIFdJVEggVEhFIFNPRlRXQVJFIE9SIFRIRSBVU0UgT1IgT1RIRVIgREVBTElOR1MgSU4gVEhFClNPRlRXQVJFLgo='
|
||||
>>> _encode_image(Path('LICENSE'))
|
||||
'TUlUIExpY2Vuc2UKCkNvcHlyaWdodCAoYykgT2xsYW1hCgpQZXJtaXNzaW9uIGlzIGhlcmVieSBncmFudGVkLCBmcmVlIG9mIGNoYXJnZSwgdG8gYW55IHBlcnNvbiBvYnRhaW5pbmcgYSBjb3B5Cm9mIHRoaXMgc29mdHdhcmUgYW5kIGFzc29jaWF0ZWQgZG9jdW1lbnRhdGlvbiBmaWxlcyAodGhlICJTb2Z0d2FyZSIpLCB0byBkZWFsCmluIHRoZSBTb2Z0d2FyZSB3aXRob3V0IHJlc3RyaWN0aW9uLCBpbmNsdWRpbmcgd2l0aG91dCBsaW1pdGF0aW9uIHRoZSByaWdodHMKdG8gdXNlLCBjb3B5LCBtb2RpZnksIG1lcmdlLCBwdWJsaXNoLCBkaXN0cmlidXRlLCBzdWJsaWNlbnNlLCBhbmQvb3Igc2VsbApjb3BpZXMgb2YgdGhlIFNvZnR3YXJlLCBhbmQgdG8gcGVybWl0IHBlcnNvbnMgdG8gd2hvbSB0aGUgU29mdHdhcmUgaXMKZnVybmlzaGVkIHRvIGRvIHNvLCBzdWJqZWN0IHRvIHRoZSBmb2xsb3dpbmcgY29uZGl0aW9uczoKClRoZSBhYm92ZSBjb3B5cmlnaHQgbm90aWNlIGFuZCB0aGlzIHBlcm1pc3Npb24gbm90aWNlIHNoYWxsIGJlIGluY2x1ZGVkIGluIGFsbApjb3BpZXMgb3Igc3Vic3RhbnRpYWwgcG9ydGlvbnMgb2YgdGhlIFNvZnR3YXJlLgoKVEhFIFNPRlRXQVJFIElTIFBST1ZJREVEICJBUyBJUyIsIFdJVEhPVVQgV0FSUkFOVFkgT0YgQU5ZIEtJTkQsIEVYUFJFU1MgT1IKSU1QTElFRCwgSU5DTFVESU5HIEJVVCBOT1QgTElNSVRFRCBUTyBUSEUgV0FSUkFOVElFUyBPRiBNRVJDSEFOVEFCSUxJVFksCkZJVE5FU1MgRk9SIEEgUEFSVElDVUxBUiBQVVJQT1NFIEFORCBOT05JTkZSSU5HRU1FTlQuIElOIE5PIEVWRU5UIFNIQUxMIFRIRQpBVVRIT1JTIE9SIENPUFlSSUdIVCBIT0xERVJTIEJFIExJQUJMRSBGT1IgQU5ZIENMQUlNLCBEQU1BR0VTIE9SIE9USEVSCkxJQUJJTElUWSwgV0hFVEhFUiBJTiBBTiBBQ1RJT04gT0YgQ09OVFJBQ1QsIFRPUlQgT1IgT1RIRVJXSVNFLCBBUklTSU5HIEZST00sCk9VVCBPRiBPUiBJTiBDT05ORUNUSU9OIFdJVEggVEhFIFNPRlRXQVJFIE9SIFRIRSBVU0UgT1IgT1RIRVIgREVBTElOR1MgSU4gVEhFClNPRlRXQVJFLgo='
|
||||
>>> _encode_image('YWJj')
|
||||
'YWJj'
|
||||
>>> _encode_image(b'YWJj')
|
||||
'YWJj'
|
||||
"""
|
||||
|
||||
return b64.decode('utf-8')
|
||||
if p := _as_path(image):
|
||||
return b64encode(p.read_bytes()).decode('utf-8')
|
||||
|
||||
try:
|
||||
b64decode(image, validate=True)
|
||||
return image if isinstance(image, str) else image.decode('utf-8')
|
||||
except (binascii.Error, TypeError):
|
||||
...
|
||||
|
||||
if b := _as_bytesio(image):
|
||||
return b64encode(b.read()).decode('utf-8')
|
||||
|
||||
raise RequestError('image must be bytes, path-like object, or file-like object')
|
||||
|
||||
|
||||
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
|
||||
if isinstance(s, str) or isinstance(s, Path):
|
||||
return Path(s)
|
||||
try:
|
||||
if (p := Path(s)).exists():
|
||||
return p
|
||||
except Exception:
|
||||
...
|
||||
return None
|
||||
|
||||
|
||||
|
||||
+111
-2
@@ -29,6 +29,7 @@ def test_client_chat(httpserver: HTTPServer):
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
@@ -75,6 +76,7 @@ def test_client_chat_stream(httpserver: HTTPServer):
|
||||
'stream': True,
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_handler(stream_handler)
|
||||
|
||||
@@ -103,6 +105,7 @@ def test_client_chat_images(httpserver: HTTPServer):
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
@@ -139,6 +142,7 @@ def test_client_generate(httpserver: HTTPServer):
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
@@ -183,6 +187,7 @@ def test_client_generate_stream(httpserver: HTTPServer):
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_handler(stream_handler)
|
||||
|
||||
@@ -210,6 +215,7 @@ def test_client_generate_images(httpserver: HTTPServer):
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
@@ -416,13 +422,61 @@ def test_client_create_modelfile(httpserver: HTTPServer):
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'name': 'dummy',
|
||||
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
|
||||
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
|
||||
{{.Prompt}} [/INST]"""
|
||||
SYSTEM """
|
||||
Use
|
||||
multiline
|
||||
strings.
|
||||
"""
|
||||
PARAMETER stop [INST]
|
||||
PARAMETER stop [/INST]
|
||||
PARAMETER stop <<SYS>>
|
||||
PARAMETER stop <</SYS>>''',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client.create(
|
||||
'dummy',
|
||||
modelfile='\n'.join(
|
||||
[
|
||||
f'FROM {blob.name}',
|
||||
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
|
||||
'{{.Prompt}} [/INST]"""',
|
||||
'SYSTEM """',
|
||||
'Use',
|
||||
'multiline',
|
||||
'strings.',
|
||||
'"""',
|
||||
'PARAMETER stop [INST]',
|
||||
'PARAMETER stop [/INST]',
|
||||
'PARAMETER stop <<SYS>>',
|
||||
'PARAMETER stop <</SYS>>',
|
||||
]
|
||||
),
|
||||
)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_from_library(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'name': 'dummy',
|
||||
'modelfile': 'FROM llama2\n',
|
||||
'modelfile': 'FROM llama2',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
@@ -465,6 +519,7 @@ async def test_async_client_chat(httpserver: HTTPServer):
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
@@ -502,6 +557,7 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
|
||||
'stream': True,
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_handler(stream_handler)
|
||||
|
||||
@@ -531,6 +587,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
@@ -558,6 +615,7 @@ async def test_async_client_generate(httpserver: HTTPServer):
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
@@ -597,6 +655,7 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_handler(stream_handler)
|
||||
|
||||
@@ -625,6 +684,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'format': '',
|
||||
'options': {},
|
||||
'keep_alive': None,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
@@ -820,6 +880,55 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer):
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'name': 'dummy',
|
||||
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
|
||||
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
|
||||
{{.Prompt}} [/INST]"""
|
||||
SYSTEM """
|
||||
Use
|
||||
multiline
|
||||
strings.
|
||||
"""
|
||||
PARAMETER stop [INST]
|
||||
PARAMETER stop [/INST]
|
||||
PARAMETER stop <<SYS>>
|
||||
PARAMETER stop <</SYS>>''',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = await client.create(
|
||||
'dummy',
|
||||
modelfile='\n'.join(
|
||||
[
|
||||
f'FROM {blob.name}',
|
||||
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
|
||||
'{{.Prompt}} [/INST]"""',
|
||||
'SYSTEM """',
|
||||
'Use',
|
||||
'multiline',
|
||||
'strings.',
|
||||
'"""',
|
||||
'PARAMETER stop [INST]',
|
||||
'PARAMETER stop [/INST]',
|
||||
'PARAMETER stop <<SYS>>',
|
||||
'PARAMETER stop <</SYS>>',
|
||||
]
|
||||
),
|
||||
)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_from_library(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
@@ -827,7 +936,7 @@ async def test_async_client_create_from_library(httpserver: HTTPServer):
|
||||
method='POST',
|
||||
json={
|
||||
'name': 'dummy',
|
||||
'modelfile': 'FROM llama2\n',
|
||||
'modelfile': 'FROM llama2',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
Reference in New Issue
Block a user