async client

This commit is contained in:
Michael Yang 2023-12-20 15:28:23 -08:00
parent 21578e3c5e
commit 20db23d932
7 changed files with 1009 additions and 482 deletions

View File

@ -1,4 +1,19 @@
from ollama.client import Client
from ._client import Client, AsyncClient
# Public names exported by the ollama package.
# NOTE(review): 'generate', 'chat', 'pull', etc. are presumably module-level
# convenience functions defined further down this file (outside the visible
# hunk) that delegate to _default_client — confirm they exist.
__all__ = [
  'Client',
  'AsyncClient',
  'generate',
  'chat',
  'pull',
  'push',
  'create',
  'delete',
  'list',
  'copy',
  'show',
]

# Shared default client backing the module-level convenience functions.
_default_client = Client()

337
ollama/_client.py Normal file
View File

@ -0,0 +1,337 @@
import io
import json
import httpx
from pathlib import Path
from hashlib import sha256
from base64 import b64encode
class BaseClient:
  '''Common construction logic shared by the sync and async clients.'''

  def __init__(self, client, base_url='http://127.0.0.1:11434'):
    # `client` is the httpx client class to instantiate
    # (httpx.Client or httpx.AsyncClient).
    settings = {
      'base_url': base_url,
      'follow_redirects': True,
      'timeout': None,
    }
    self._client = client(**settings)
class Client(BaseClient):
  '''Synchronous client for the ollama HTTP API.'''

  def __init__(self, base='http://localhost:11434'):
    super().__init__(httpx.Client, base)

  def _request(self, method, url, **kwargs):
    # Raises httpx.HTTPStatusError for 4xx/5xx responses.
    response = self._client.request(method, url, **kwargs)
    response.raise_for_status()
    return response

  def _request_json(self, method, url, **kwargs):
    # Returns the parsed JSON body (a dict), NOT the httpx.Response.
    return self._request(method, url, **kwargs).json()

  def _stream(self, method, url, **kwargs):
    # Yields one parsed JSON object per newline-delimited response line.
    with self._client.stream(method, url, **kwargs) as r:
      for line in r.iter_lines():
        part = json.loads(line)
        if e := part.get('error'):
          raise Exception(e)
        yield part

  def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
    '''Generate a completion. Returns a dict, or a generator of dicts when stream=True.'''
    if not model:
      raise Exception('must provide a model')
    fn = self._stream if stream else self._request_json
    return fn('POST', '/api/generate', json={
      'model': model,
      'prompt': prompt,
      'system': system,
      'template': template,
      'context': context or [],
      'stream': stream,
      'raw': raw,
      'images': [_encode_image(image) for image in images or []],
      'format': format,
      'options': options or {},
    })

  def chat(self, model='', messages=None, stream=False, format='', options=None):
    '''Chat with a model. Returns a dict, or a generator of dicts when stream=True.'''
    if not model:
      raise Exception('must provide a model')
    for message in messages or []:
      if not isinstance(message, dict):
        # BUG FIX: the message previously said 'strings', but dicts are required.
        raise TypeError('messages must be a list of dicts')
      if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
        raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
      if not message.get('content'):
        raise Exception('messages must contain content')
      if images := message.get('images'):
        # NOTE: mutates the caller's message dict in place.
        message['images'] = [_encode_image(image) for image in images]
    fn = self._stream if stream else self._request_json
    return fn('POST', '/api/chat', json={
      'model': model,
      'messages': messages,
      'stream': stream,
      'format': format,
      'options': options or {},
    })

  def pull(self, model, insecure=False, stream=False):
    '''Pull a model from a registry.'''
    fn = self._stream if stream else self._request_json
    return fn('POST', '/api/pull', json={
      'model': model,
      'insecure': insecure,
      'stream': stream,
    })

  def push(self, model, insecure=False, stream=False):
    '''Push a model to a registry.'''
    fn = self._stream if stream else self._request_json
    return fn('POST', '/api/push', json={
      'model': model,
      'insecure': insecure,
      'stream': stream,
    })

  def create(self, model, path=None, modelfile=None, stream=False):
    '''Create a model from a Modelfile path or Modelfile text.'''
    if (path := _as_path(path)) and path.exists():
      modelfile = self._parse_modelfile(path.read_text(), base=path.parent)
    elif modelfile:
      modelfile = self._parse_modelfile(modelfile)
    else:
      raise Exception('must provide either path or modelfile')
    fn = self._stream if stream else self._request_json
    return fn('POST', '/api/create', json={
      'model': model,
      'modelfile': modelfile,
      'stream': stream,
    })

  def _parse_modelfile(self, modelfile, base=None):
    # Rewrite FROM/ADAPTER lines that reference existing local files as
    # '@<digest>' after uploading the file as a blob.
    base = Path.cwd() if base is None else base
    out = io.StringIO()
    for line in io.StringIO(modelfile):
      command, _, args = line.partition(' ')
      if command.upper() in ['FROM', 'ADAPTER']:
        path = Path(args).expanduser()
        path = path if path.is_absolute() else base / path
        if path.exists():
          args = f'@{self._create_blob(path)}'
      print(command, args, file=out)
    return out.getvalue()

  def _create_blob(self, path):
    # Hash the file, then upload it only if the server does not already
    # have a blob with that digest (HEAD returns 404).
    sha256sum = sha256()
    with open(path, 'rb') as r:
      while True:
        chunk = r.read(32 * 1024)
        if not chunk:
          break
        sha256sum.update(chunk)
    digest = f'sha256:{sha256sum.hexdigest()}'
    try:
      self._request('HEAD', f'/api/blobs/{digest}')
    except httpx.HTTPStatusError as e:
      if e.response.status_code != 404:
        raise
      with open(path, 'rb') as r:
        self._request('PUT', f'/api/blobs/{digest}', content=r)
    return digest

  def delete(self, model):
    '''Delete a model. Returns {'status': 'success'} or {'status': 'error'}.'''
    # BUG FIX: _request_json returns a parsed dict, which has no
    # .status_code attribute; use the raw response instead.
    response = self._request('DELETE', '/api/delete', json={'model': model})
    return {'status': 'success' if response.status_code == 200 else 'error'}

  def list(self):
    '''List locally available models.'''
    return self._request_json('GET', '/api/tags').get('models', [])

  def copy(self, source, target):
    '''Copy a model. Returns {'status': 'success'} or {'status': 'error'}.'''
    # BUG FIX: same dict-vs-response confusion as delete().
    response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
    return {'status': 'success' if response.status_code == 200 else 'error'}

  def show(self, model):
    '''Show details for a model.'''
    # NOTE(review): sends a JSON body with GET; httpx permits this and the
    # ollama server reads it — confirm before changing.
    return self._request_json('GET', '/api/show', json={'model': model})
class AsyncClient(BaseClient):
  '''Asynchronous client for the ollama HTTP API (httpx.AsyncClient).'''

  def __init__(self, base='http://localhost:11434'):
    super().__init__(httpx.AsyncClient, base)

  async def _request(self, method, url, **kwargs):
    # Raises httpx.HTTPStatusError for 4xx/5xx responses.
    response = await self._client.request(method, url, **kwargs)
    response.raise_for_status()
    return response

  async def _request_json(self, method, url, **kwargs):
    # Returns the parsed JSON body (a dict), NOT the httpx.Response.
    response = await self._request(method, url, **kwargs)
    return response.json()

  async def _stream(self, method, url, **kwargs):
    # Returns an async generator yielding one parsed JSON object per line.
    async def inner():
      async with self._client.stream(method, url, **kwargs) as r:
        async for line in r.aiter_lines():
          part = json.loads(line)
          if e := part.get('error'):
            raise Exception(e)
          yield part

    return inner()

  async def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
    '''Generate a completion. Returns a dict, or an async generator when stream=True.'''
    if not model:
      raise Exception('must provide a model')
    fn = self._stream if stream else self._request_json
    return await fn('POST', '/api/generate', json={
      'model': model,
      'prompt': prompt,
      'system': system,
      'template': template,
      'context': context or [],
      'stream': stream,
      'raw': raw,
      'images': [_encode_image(image) for image in images or []],
      'format': format,
      'options': options or {},
    })

  async def chat(self, model='', messages=None, stream=False, format='', options=None):
    '''Chat with a model. Returns a dict, or an async generator when stream=True.'''
    if not model:
      raise Exception('must provide a model')
    for message in messages or []:
      if not isinstance(message, dict):
        # BUG FIX: the message previously said 'strings', but dicts are required.
        raise TypeError('messages must be a list of dicts')
      if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
        raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
      if not message.get('content'):
        raise Exception('messages must contain content')
      if images := message.get('images'):
        # NOTE: mutates the caller's message dict in place.
        message['images'] = [_encode_image(image) for image in images]
    fn = self._stream if stream else self._request_json
    return await fn('POST', '/api/chat', json={
      'model': model,
      'messages': messages,
      'stream': stream,
      'format': format,
      'options': options or {},
    })

  async def pull(self, model, insecure=False, stream=False):
    '''Pull a model from a registry.'''
    fn = self._stream if stream else self._request_json
    return await fn('POST', '/api/pull', json={
      'model': model,
      'insecure': insecure,
      'stream': stream,
    })

  async def push(self, model, insecure=False, stream=False):
    '''Push a model to a registry.'''
    fn = self._stream if stream else self._request_json
    return await fn('POST', '/api/push', json={
      'model': model,
      'insecure': insecure,
      'stream': stream,
    })

  async def create(self, model, path=None, modelfile=None, stream=False):
    '''Create a model from a Modelfile path or Modelfile text.'''
    if (path := _as_path(path)) and path.exists():
      modelfile = await self._parse_modelfile(path.read_text(), base=path.parent)
    elif modelfile:
      modelfile = await self._parse_modelfile(modelfile)
    else:
      raise Exception('must provide either path or modelfile')
    fn = self._stream if stream else self._request_json
    return await fn('POST', '/api/create', json={
      'model': model,
      'modelfile': modelfile,
      'stream': stream,
    })

  async def _parse_modelfile(self, modelfile, base=None):
    # Rewrite FROM/ADAPTER lines that reference existing local files as
    # '@<digest>' after uploading the file as a blob.
    base = Path.cwd() if base is None else base
    out = io.StringIO()
    for line in io.StringIO(modelfile):
      command, _, args = line.partition(' ')
      if command.upper() in ['FROM', 'ADAPTER']:
        path = Path(args).expanduser()
        path = path if path.is_absolute() else base / path
        if path.exists():
          args = f'@{await self._create_blob(path)}'
      print(command, args, file=out)
    return out.getvalue()

  async def _create_blob(self, path):
    # Hash the file, then upload it only if the server does not already
    # have a blob with that digest (HEAD returns 404). The upload streams
    # the file through an async chunk generator.
    sha256sum = sha256()
    with open(path, 'rb') as r:
      while True:
        chunk = r.read(32 * 1024)
        if not chunk:
          break
        sha256sum.update(chunk)
    digest = f'sha256:{sha256sum.hexdigest()}'
    try:
      await self._request('HEAD', f'/api/blobs/{digest}')
    except httpx.HTTPStatusError as e:
      if e.response.status_code != 404:
        raise

      async def upload_bytes():
        with open(path, 'rb') as r:
          while True:
            chunk = r.read(32 * 1024)
            if not chunk:
              break
            yield chunk

      await self._request('PUT', f'/api/blobs/{digest}', content=upload_bytes())
    return digest

  async def delete(self, model):
    '''Delete a model. Returns {'status': 'success'} or {'status': 'error'}.'''
    # BUG FIX: _request_json returns a parsed dict, which has no
    # .status_code attribute; use the raw response instead.
    response = await self._request('DELETE', '/api/delete', json={'model': model})
    return {'status': 'success' if response.status_code == 200 else 'error'}

  async def list(self):
    '''List locally available models.'''
    response = await self._request_json('GET', '/api/tags')
    return response.get('models', [])

  async def copy(self, source, target):
    '''Copy a model. Returns {'status': 'success'} or {'status': 'error'}.'''
    # BUG FIX: same dict-vs-response confusion as delete().
    response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target})
    return {'status': 'success' if response.status_code == 200 else 'error'}

  async def show(self, model):
    '''Show details for a model.'''
    # NOTE(review): sends a JSON body with GET; httpx permits this and the
    # ollama server reads it — confirm before changing.
    return await self._request_json('GET', '/api/show', json={'model': model})
def _encode_image(image):
  # Accept a path (str/Path), raw bytes, or a BytesIO and return the
  # image content as base64-encoded text.
  if path := _as_path(image):
    return b64encode(path.read_bytes()).decode('utf-8')
  if buffer := _as_bytesio(image):
    return b64encode(buffer.read()).decode('utf-8')
  raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
def _as_path(s):
if isinstance(s, str) or isinstance(s, Path):
return Path(s)
return None
def _as_bytesio(s):
if isinstance(s, io.BytesIO):
return s
elif isinstance(s, bytes):
return io.BytesIO(s)
return None

View File

@ -1,188 +0,0 @@
import io
import json
import httpx
from pathlib import Path
from hashlib import sha256
from base64 import b64encode
class BaseClient:
  '''Holds the configured httpx client instance for subclasses.'''

  def __init__(self, client, base_url='http://127.0.0.1:11434'):
    # `client` is the httpx client class to instantiate.
    self._client = client(
      base_url=base_url,
      follow_redirects=True,
      timeout=None,
    )
class Client(BaseClient):
  '''Synchronous client for the ollama HTTP API (pre-async-refactor module).'''

  def __init__(self, base='http://localhost:11434'):
    super().__init__(httpx.Client, base)

  def _request(self, method, url, **kwargs):
    # Raises httpx.HTTPStatusError for 4xx/5xx responses.
    response = self._client.request(method, url, **kwargs)
    response.raise_for_status()
    return response

  def _request_json(self, method, url, **kwargs):
    # Returns the parsed JSON body (a dict), NOT the httpx.Response.
    return self._request(method, url, **kwargs).json()

  def stream(self, method, url, **kwargs):
    # Yields one parsed JSON object per newline-delimited response line.
    with self._client.stream(method, url, **kwargs) as r:
      for line in r.iter_lines():
        part = json.loads(line)
        if e := part.get('error'):
          raise Exception(e)
        yield part

  def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
    '''Generate a completion. Returns a dict, or a generator of dicts when stream=True.'''
    if not model:
      raise Exception('must provide a model')
    fn = self.stream if stream else self._request_json
    return fn('POST', '/api/generate', json={
      'model': model,
      'prompt': prompt,
      'system': system,
      'template': template,
      'context': context or [],
      'stream': stream,
      'raw': raw,
      'images': [_encode_image(image) for image in images or []],
      'format': format,
      'options': options or {},
    })

  def chat(self, model='', messages=None, stream=False, format='', options=None):
    '''Chat with a model. Returns a dict, or a generator of dicts when stream=True.'''
    if not model:
      raise Exception('must provide a model')
    for message in messages or []:
      if not isinstance(message, dict):
        # BUG FIX: the message previously said 'strings', but dicts are required.
        raise TypeError('messages must be a list of dicts')
      if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
        raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
      if not message.get('content'):
        raise Exception('messages must contain content')
      if images := message.get('images'):
        message['images'] = [_encode_image(image) for image in images]
    fn = self.stream if stream else self._request_json
    return fn('POST', '/api/chat', json={
      'model': model,
      'messages': messages,
      'stream': stream,
      'format': format,
      'options': options or {},
    })

  def pull(self, model, insecure=False, stream=False):
    '''Pull a model from a registry.'''
    fn = self.stream if stream else self._request_json
    return fn('POST', '/api/pull', json={
      'model': model,
      'insecure': insecure,
      'stream': stream,
    })

  def push(self, model, insecure=False, stream=False):
    '''Push a model to a registry.'''
    fn = self.stream if stream else self._request_json
    return fn('POST', '/api/push', json={
      'model': model,
      'insecure': insecure,
      'stream': stream,
    })

  def create(self, model, path=None, modelfile=None, stream=False):
    '''Create a model from a Modelfile path or Modelfile text.'''
    if (path := _as_path(path)) and path.exists():
      modelfile = _parse_modelfile(path.read_text(), self.create_blob, base=path.parent)
    elif modelfile:
      modelfile = _parse_modelfile(modelfile, self.create_blob)
    else:
      raise Exception('must provide either path or modelfile')
    fn = self.stream if stream else self._request_json
    return fn('POST', '/api/create', json={
      'model': model,
      'modelfile': modelfile,
      'stream': stream,
    })

  def create_blob(self, path):
    '''Upload a file as a blob (if not already present) and return its digest.'''
    sha256sum = sha256()
    with open(path, 'rb') as r:
      while True:
        chunk = r.read(32 * 1024)
        if not chunk:
          break
        sha256sum.update(chunk)
    digest = f'sha256:{sha256sum.hexdigest()}'
    try:
      self._request('HEAD', f'/api/blobs/{digest}')
    except httpx.HTTPStatusError as e:
      # BUG FIX: previously ANY httpx.HTTPError (including transport
      # failures and non-404 statuses) triggered an upload attempt;
      # only a 404 "blob missing" response should.
      if e.response.status_code != 404:
        raise
      with open(path, 'rb') as r:
        self._request('PUT', f'/api/blobs/{digest}', content=r)
    return digest

  def delete(self, model):
    '''Delete a model. Returns {'status': 'success'} or {'status': 'error'}.'''
    # BUG FIX: _request_json returns a parsed dict, which has no
    # .status_code attribute; use the raw response instead.
    response = self._request('DELETE', '/api/delete', json={'model': model})
    return {'status': 'success' if response.status_code == 200 else 'error'}

  def list(self):
    '''List locally available models.'''
    return self._request_json('GET', '/api/tags').get('models', [])

  def copy(self, source, target):
    '''Copy a model. Returns {'status': 'success'} or {'status': 'error'}.'''
    # BUG FIX: same dict-vs-response confusion as delete().
    response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
    return {'status': 'success' if response.status_code == 200 else 'error'}

  def show(self, model):
    '''Show details for a model.'''
    # BUG FIX: _request_json already returns parsed JSON; the extra
    # .json() call raised AttributeError on the returned dict.
    return self._request_json('GET', '/api/show', json={'model': model})
def _encode_image(image):
  '''
  Return the given image as base64-encoded text.

  Accepted inputs:
  - str or Path: treated as a path to an image file.
  - bytes: treated as the raw bytes of an image.
  - file-like object: treated as a container of the raw bytes of an image.
  '''
  if path := _as_path(image):
    return b64encode(path.read_bytes()).decode('utf-8')
  if buffer := _as_bytesio(image):
    return b64encode(buffer.read()).decode('utf-8')
  raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
def _parse_modelfile(modelfile, cb, base=None):
base = Path.cwd() if base is None else base
out = io.StringIO()
for line in io.StringIO(modelfile):
command, _, args = line.partition(' ')
if command.upper() in ['FROM', 'ADAPTER']:
path = Path(args).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{cb(path)}'
print(command, args, file=out)
return out.getvalue()
def _as_path(s):
if isinstance(s, str) or isinstance(s, Path):
return Path(s)
return None
def _as_bytesio(s):
if isinstance(s, io.BytesIO):
return s
elif isinstance(s, bytes):
return io.BytesIO(s)
return None

View File

@ -1,292 +0,0 @@
import pytest
import os
import io
import types
import tempfile
from pathlib import Path
from ollama.client import Client
from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Response
from PIL import Image
class PrefixPattern(URIPattern):
  '''URI pattern that matches any request path starting with a prefix.'''

  def __init__(self, prefix: str):
    self.prefix = prefix

  def match(self, uri):
    return uri[:len(self.prefix)] == self.prefix
def test_client_chat(httpserver: HTTPServer):
  # Non-streaming chat: one POST /api/chat, parsed JSON dict returned.
  httpserver.expect_ordered_request('/api/chat', method='POST', json={
    'model': 'dummy',
    'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
    'stream': False,
    'format': '',
    'options': {},
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
  assert isinstance(response, dict)
def test_client_chat_stream(httpserver: HTTPServer):
  # stream=True returns a generator (body is not consumed here).
  httpserver.expect_ordered_request('/api/chat', method='POST', json={
    'model': 'dummy',
    'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
    'stream': True,
    'format': '',
    'options': {},
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
  assert isinstance(response, types.GeneratorType)
def test_client_chat_images(httpserver: HTTPServer):
  # In-memory image bytes are base64-encoded into the message payload
  # (expected string is a base64 1x1 PNG).
  httpserver.expect_ordered_request('/api/chat', method='POST', json={
    'model': 'dummy',
    'messages': [
      {
        'role': 'user',
        'content': 'Why is the sky blue?',
        'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
      },
    ],
    'stream': False,
    'format': '',
    'options': {},
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  with io.BytesIO() as b:
    Image.new('RGB', (1, 1)).save(b, 'PNG')
    response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
    assert isinstance(response, dict)
def test_client_generate(httpserver: HTTPServer):
  # Non-streaming generate: defaults fill every other payload field.
  httpserver.expect_ordered_request('/api/generate', method='POST', json={
    'model': 'dummy',
    'prompt': 'Why is the sky blue?',
    'system': '',
    'template': '',
    'context': [],
    'stream': False,
    'raw': False,
    'images': [],
    'format': '',
    'options': {},
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  response = client.generate('dummy', 'Why is the sky blue?')
  assert isinstance(response, dict)
def test_client_generate_stream(httpserver: HTTPServer):
  # stream=True returns a generator (body is not consumed here).
  httpserver.expect_ordered_request('/api/generate', method='POST', json={
    'model': 'dummy',
    'prompt': 'Why is the sky blue?',
    'system': '',
    'template': '',
    'context': [],
    'stream': True,
    'raw': False,
    'images': [],
    'format': '',
    'options': {},
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  response = client.generate('dummy', 'Why is the sky blue?', stream=True)
  assert isinstance(response, types.GeneratorType)
def test_client_generate_images(httpserver: HTTPServer):
  # A file path in images=[...] is read and base64-encoded into the payload.
  httpserver.expect_ordered_request('/api/generate', method='POST', json={
    'model': 'dummy',
    'prompt': 'Why is the sky blue?',
    'system': '',
    'template': '',
    'context': [],
    'stream': False,
    'raw': False,
    'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
    'format': '',
    'options': {},
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  with tempfile.NamedTemporaryFile() as temp:
    Image.new('RGB', (1, 1)).save(temp, 'PNG')
    response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
    assert isinstance(response, dict)
def test_client_pull(httpserver: HTTPServer):
  # Non-streaming pull returns the parsed JSON dict.
  httpserver.expect_ordered_request('/api/pull', method='POST', json={
    'model': 'dummy',
    'insecure': False,
    'stream': False,
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  response = client.pull('dummy')
  assert isinstance(response, dict)
def test_client_pull_stream(httpserver: HTTPServer):
  # stream=True returns a generator (body is not consumed here).
  httpserver.expect_ordered_request('/api/pull', method='POST', json={
    'model': 'dummy',
    'insecure': False,
    'stream': True,
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  response = client.pull('dummy', stream=True)
  assert isinstance(response, types.GeneratorType)
def test_client_push(httpserver: HTTPServer):
  # Non-streaming push returns the parsed JSON dict.
  httpserver.expect_ordered_request('/api/push', method='POST', json={
    'model': 'dummy',
    'insecure': False,
    'stream': False,
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  response = client.push('dummy')
  assert isinstance(response, dict)
def test_client_push_stream(httpserver: HTTPServer):
  # stream=True returns a generator (body is not consumed here).
  httpserver.expect_ordered_request('/api/push', method='POST', json={
    'model': 'dummy',
    'insecure': False,
    'stream': True,
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  response = client.push('dummy', stream=True)
  assert isinstance(response, types.GeneratorType)
def test_client_create_path(httpserver: HTTPServer):
  # An absolute FROM path in the Modelfile is replaced by '@<digest>';
  # HEAD 200 means the blob already exists so no PUT is expected.
  # (Digest below is sha256 of an empty file.)
  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
  httpserver.expect_ordered_request('/api/create', method='POST', json={
    'model': 'dummy',
    'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
    'stream': False,
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  with tempfile.NamedTemporaryFile() as modelfile:
    with tempfile.NamedTemporaryFile() as blob:
      modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
      modelfile.flush()
      response = client.create('dummy', path=modelfile.name)
      assert isinstance(response, dict)
def test_client_create_path_relative(httpserver: HTTPServer):
  # A relative FROM path is resolved against the Modelfile's directory.
  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
  httpserver.expect_ordered_request('/api/create', method='POST', json={
    'model': 'dummy',
    'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
    'stream': False,
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  with tempfile.NamedTemporaryFile() as modelfile:
    with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
      modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
      modelfile.flush()
      response = client.create('dummy', path=modelfile.name)
      assert isinstance(response, dict)
@pytest.fixture
def userhomedir():
  # Temporarily point $HOME at a fresh directory so '~' expansion is
  # test-controlled; restores the original value during teardown.
  with tempfile.TemporaryDirectory() as temp:
    home = os.getenv('HOME', '')
    os.environ['HOME'] = temp
    yield Path(temp)
    os.environ['HOME'] = home
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
  # A '~/' FROM path expands via the fixture-controlled $HOME.
  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
  httpserver.expect_ordered_request('/api/create', method='POST', json={
    'model': 'dummy',
    'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
    'stream': False,
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  with tempfile.NamedTemporaryFile() as modelfile:
    with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
      modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
      modelfile.flush()
      response = client.create('dummy', path=modelfile.name)
      assert isinstance(response, dict)
def test_client_create_modelfile(httpserver: HTTPServer):
  # Modelfile passed as text (not a path) is parsed the same way.
  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
  httpserver.expect_ordered_request('/api/create', method='POST', json={
    'model': 'dummy',
    'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
    'stream': False,
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  with tempfile.NamedTemporaryFile() as blob:
    response = client.create('dummy', modelfile=f'FROM {blob.name}')
    assert isinstance(response, dict)
def test_client_create_from_library(httpserver: HTTPServer):
  # A FROM referencing a library model (no local file) passes through
  # unchanged — no blob endpoints are hit.
  httpserver.expect_ordered_request('/api/create', method='POST', json={
    'model': 'dummy',
    'modelfile': 'FROM llama2\n',
    'stream': False,
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  response = client.create('dummy', modelfile='FROM llama2')
  assert isinstance(response, dict)
def test_client_create_blob(httpserver: HTTPServer):
  # HEAD 404 (blob missing) triggers a PUT upload; the returned digest
  # is sha256 of the empty file.
  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))

  client = Client(httpserver.url_for('/'))
  with tempfile.NamedTemporaryFile() as blob:
    response = client.create_blob(blob.name)
    assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
def test_client_create_blob_exists(httpserver: HTTPServer):
  # HEAD 200 (blob already present) must NOT be followed by a PUT —
  # only one ordered expectation is registered.
  httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))

  client = Client(httpserver.url_for('/'))
  with tempfile.NamedTemporaryFile() as blob:
    response = client.create_blob(blob.name)
    assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'

20
poetry.lock generated
View File

@ -387,6 +387,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-asyncio"
version = "0.23.2"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"},
{file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"},
]
[package.dependencies]
pytest = ">=7.0.0"
[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]]
name = "pytest-cov"
version = "4.1.0"
@ -498,4 +516,4 @@ watchdog = ["watchdog (>=2.3)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.8"
content-hash = "b9f64e1a5795a417d2dbff7286360f8d3f8f10fdfa9580411940d144c2561e92"
content-hash = "9416a897c95d3c80cf1bfd3cc61cd19f0143c9bd0bc7c219fcb31ee27c497c9d"

View File

@ -11,6 +11,7 @@ httpx = "^0.25.2"
[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
pytest-asyncio = "^0.23.2"
pytest-cov = "^4.1.0"
pytest-httpserver = "^1.0.8"
pillow = "^10.1.0"

636
tests/test_client.py Normal file
View File

@ -0,0 +1,636 @@
import os
import io
import json
import types
import pytest
import tempfile
from pathlib import Path
from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Request, Response
from PIL import Image
from ollama._client import Client, AsyncClient
class PrefixPattern(URIPattern):
  '''URI pattern matching every request path that begins with `prefix`.'''

  def __init__(self, prefix: str):
    self.prefix = prefix

  def match(self, uri):
    return uri[:len(self.prefix)] == self.prefix
def test_client_chat(httpserver: HTTPServer):
  # Non-streaming chat: request payload and parsed response fields.
  httpserver.expect_ordered_request('/api/chat', method='POST', json={
    'model': 'dummy',
    'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
    'stream': False,
    'format': '',
    'options': {},
  }).respond_with_json({
    'model': 'dummy',
    'message': {
      'role': 'assistant',
      'content': "I don't know.",
    },
  })

  client = Client(httpserver.url_for('/'))
  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
  assert response['model'] == 'dummy'
  assert response['message']['role'] == 'assistant'
  assert response['message']['content'] == "I don't know."
def test_client_chat_stream(httpserver: HTTPServer):
  '''Streaming chat yields one parsed dict per NDJSON line.'''

  def stream_handler(_: Request):
    def generate():
      for message in ['I ', "don't ", 'know.']:
        yield json.dumps({
          'model': 'dummy',
          'message': {
            'role': 'assistant',
            'content': message,
          },
        }) + '\n'

    return Response(generate())

  httpserver.expect_ordered_request('/api/chat', method='POST', json={
    'model': 'dummy',
    'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
    'stream': True,
    'format': '',
    'options': {},
  }).respond_with_handler(stream_handler)

  client = Client(httpserver.url_for('/'))
  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
  for part in response:
    # BUG FIX: was `role in 'assistant'` — a substring test that also
    # passes for '' or 'a'; equality is what is intended.
    assert part['message']['role'] == 'assistant'
    assert part['message']['content'] in ['I ', "don't ", 'know.']
def test_client_chat_images(httpserver: HTTPServer):
  # In-memory image bytes are base64-encoded into the message payload
  # (expected string is a base64 1x1 PNG).
  httpserver.expect_ordered_request('/api/chat', method='POST', json={
    'model': 'dummy',
    'messages': [
      {
        'role': 'user',
        'content': 'Why is the sky blue?',
        'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
      },
    ],
    'stream': False,
    'format': '',
    'options': {},
  }).respond_with_json({
    'model': 'dummy',
    'message': {
      'role': 'assistant',
      'content': "I don't know.",
    },
  })

  client = Client(httpserver.url_for('/'))
  with io.BytesIO() as b:
    Image.new('RGB', (1, 1)).save(b, 'PNG')
    response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
    assert response['model'] == 'dummy'
    assert response['message']['role'] == 'assistant'
    assert response['message']['content'] == "I don't know."
def test_client_generate(httpserver: HTTPServer):
  # Non-streaming generate: defaults fill every other payload field.
  httpserver.expect_ordered_request('/api/generate', method='POST', json={
    'model': 'dummy',
    'prompt': 'Why is the sky blue?',
    'system': '',
    'template': '',
    'context': [],
    'stream': False,
    'raw': False,
    'images': [],
    'format': '',
    'options': {},
  }).respond_with_json({
    'model': 'dummy',
    'response': 'Because it is.',
  })

  client = Client(httpserver.url_for('/'))
  response = client.generate('dummy', 'Why is the sky blue?')
  assert response['model'] == 'dummy'
  assert response['response'] == 'Because it is.'
def test_client_generate_stream(httpserver: HTTPServer):
  # Streaming generate: handler emits NDJSON; the client yields one
  # parsed dict per line.
  def stream_handler(_: Request):
    def generate():
      for message in ['Because ', 'it ', 'is.']:
        yield json.dumps({
          'model': 'dummy',
          'response': message,
        }) + '\n'

    return Response(generate())

  httpserver.expect_ordered_request('/api/generate', method='POST', json={
    'model': 'dummy',
    'prompt': 'Why is the sky blue?',
    'system': '',
    'template': '',
    'context': [],
    'stream': True,
    'raw': False,
    'images': [],
    'format': '',
    'options': {},
  }).respond_with_handler(stream_handler)

  client = Client(httpserver.url_for('/'))
  response = client.generate('dummy', 'Why is the sky blue?', stream=True)
  for part in response:
    assert part['model'] == 'dummy'
    assert part['response'] in ['Because ', 'it ', 'is.']
def test_client_generate_images(httpserver: HTTPServer):
  # A file path in images=[...] is read and base64-encoded into the payload.
  httpserver.expect_ordered_request('/api/generate', method='POST', json={
    'model': 'dummy',
    'prompt': 'Why is the sky blue?',
    'system': '',
    'template': '',
    'context': [],
    'stream': False,
    'raw': False,
    'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
    'format': '',
    'options': {},
  }).respond_with_json({
    'model': 'dummy',
    'response': 'Because it is.',
  })

  client = Client(httpserver.url_for('/'))
  with tempfile.NamedTemporaryFile() as temp:
    Image.new('RGB', (1, 1)).save(temp, 'PNG')
    response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
    assert response['model'] == 'dummy'
    assert response['response'] == 'Because it is.'
def test_client_pull(httpserver: HTTPServer):
  # Non-streaming pull returns the final status dict.
  httpserver.expect_ordered_request('/api/pull', method='POST', json={
    'model': 'dummy',
    'insecure': False,
    'stream': False,
  }).respond_with_json({
    'status': 'success',
  })

  client = Client(httpserver.url_for('/'))
  response = client.pull('dummy')
  assert response['status'] == 'success'
def test_client_pull_stream(httpserver: HTTPServer):
  '''Streaming pull yields one status dict per NDJSON progress line.'''

  def stream_handler(_: Request):
    def generate():
      yield json.dumps({'status': 'pulling manifest'}) + '\n'
      yield json.dumps({'status': 'verifying sha256 digest'}) + '\n'
      yield json.dumps({'status': 'writing manifest'}) + '\n'
      yield json.dumps({'status': 'removing any unused layers'}) + '\n'
      yield json.dumps({'status': 'success'}) + '\n'

    return Response(generate())

  # BUG FIX: stream_handler was defined but never wired up — the
  # expectation responded with respond_with_json({}) instead, leaving
  # the handler dead code and the stream contents unverified.
  httpserver.expect_ordered_request('/api/pull', method='POST', json={
    'model': 'dummy',
    'insecure': False,
    'stream': True,
  }).respond_with_handler(stream_handler)

  client = Client(httpserver.url_for('/'))
  response = client.pull('dummy', stream=True)
  assert isinstance(response, types.GeneratorType)
  # Consume the stream so the handler's payload is actually checked.
  statuses = [part['status'] for part in response]
  assert statuses[-1] == 'success'
def test_client_push(httpserver: HTTPServer):
  # Non-streaming push returns the parsed JSON dict.
  httpserver.expect_ordered_request('/api/push', method='POST', json={
    'model': 'dummy',
    'insecure': False,
    'stream': False,
  }).respond_with_json({})

  client = Client(httpserver.url_for('/'))
  response = client.push('dummy')
  assert isinstance(response, dict)
def test_client_push_stream(httpserver: HTTPServer):
    """Streaming push: the client returns a lazy generator (not consumed here)."""
    expected_body = {
        'model': 'dummy',
        'insecure': False,
        'stream': True,
    }
    httpserver.expect_ordered_request('/api/push', method='POST', json=expected_body).respond_with_json({})
    result = Client(httpserver.url_for('/')).push('dummy', stream=True)
    assert isinstance(result, types.GeneratorType)
def test_client_create_path(httpserver: HTTPServer):
    """create(path=...): an absolute FROM path is replaced by the blob's sha256 digest."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
    expected_body = {
        # e3b0c442... is the sha256 digest of empty content (the blob file is empty).
        'model': 'dummy',
        'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/create', method='POST', json=expected_body).respond_with_json({})
    client = Client(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as modelfile:
        with tempfile.NamedTemporaryFile() as blob:
            modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
            modelfile.flush()
            result = client.create('dummy', path=modelfile.name)
    assert isinstance(result, dict)
def test_client_create_path_relative(httpserver: HTTPServer):
    """create(path=...): a FROM path relative to the modelfile's directory is resolved."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
    expected_body = {
        'model': 'dummy',
        'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/create', method='POST', json=expected_body).respond_with_json({})
    client = Client(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as modelfile:
        # Place the blob next to the modelfile so a bare filename resolves against it.
        with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
            modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
            modelfile.flush()
            result = client.create('dummy', path=modelfile.name)
    assert isinstance(result, dict)
@pytest.fixture
def userhomedir():
    """Point $HOME at a fresh temporary directory for the duration of one test.

    Yields the temporary directory as a Path. Restoration happens in a
    ``finally`` block so a failing test (pytest throws the failure into the
    generator at the yield point) cannot leak the temporary HOME into later
    tests. If HOME was unset before the test, it is removed again instead of
    being left as an empty string.
    """
    with tempfile.TemporaryDirectory() as temp:
        home = os.environ.get('HOME')
        os.environ['HOME'] = temp
        try:
            yield Path(temp)
        finally:
            if home is None:
                os.environ.pop('HOME', None)
            else:
                os.environ['HOME'] = home
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
    """create(path=...): a ``~/``-prefixed FROM path is expanded against $HOME."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
    expected_body = {
        'model': 'dummy',
        'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/create', method='POST', json=expected_body).respond_with_json({})
    client = Client(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as modelfile:
        # The blob lives inside the fixture-provided fake home directory.
        with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
            modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
            modelfile.flush()
            result = client.create('dummy', path=modelfile.name)
    assert isinstance(result, dict)
def test_client_create_modelfile(httpserver: HTTPServer):
    """create(modelfile=...): an inline modelfile's FROM path is rewritten to a digest."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
    expected_body = {
        'model': 'dummy',
        'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/create', method='POST', json=expected_body).respond_with_json({})
    client = Client(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as blob:
        result = client.create('dummy', modelfile=f'FROM {blob.name}')
    assert isinstance(result, dict)
def test_client_create_from_library(httpserver: HTTPServer):
    """create(): a FROM naming a library model (no local file) is sent through unchanged."""
    expected_body = {
        'model': 'dummy',
        'modelfile': 'FROM llama2\n',
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/create', method='POST', json=expected_body).respond_with_json({})
    result = Client(httpserver.url_for('/')).create('dummy', modelfile='FROM llama2')
    assert isinstance(result, dict)
def test_client_create_blob(httpserver: HTTPServer):
    """_create_blob: uploads via PUT when the HEAD probe reports the blob missing (404)."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
    client = Client(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as blob:
        digest = client._create_blob(blob.name)
    # sha256 digest of empty content, since the temp file is never written to.
    assert digest == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
def test_client_create_blob_exists(httpserver: HTTPServer):
    """_create_blob: skips the upload (no PUT expected) when HEAD reports the blob present."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
    client = Client(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as blob:
        digest = client._create_blob(blob.name)
    assert digest == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
@pytest.mark.asyncio
async def test_async_client_chat(httpserver: HTTPServer):
    """Async non-streaming chat: posts the expected payload and returns a dict."""
    expected_body = {
        'model': 'dummy',
        'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
        'stream': False,
        'format': '',
        'options': {},
    }
    httpserver.expect_ordered_request('/api/chat', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    result = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
    assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_async_client_chat_stream(httpserver: HTTPServer):
    """Async streaming chat: the client returns an async generator (not consumed here)."""
    expected_body = {
        'model': 'dummy',
        'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
        'stream': True,
        'format': '',
        'options': {},
    }
    httpserver.expect_ordered_request('/api/chat', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    result = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
    assert isinstance(result, types.AsyncGeneratorType)
@pytest.mark.asyncio
async def test_async_client_chat_images(httpserver: HTTPServer):
    """Async chat: raw image bytes in a message are base64-encoded into the payload."""
    expected_body = {
        'model': 'dummy',
        'messages': [
            {
                'role': 'user',
                'content': 'Why is the sky blue?',
                # base64 of a 1x1 RGB PNG, matching the image generated below
                'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
            },
        ],
        'stream': False,
        'format': '',
        'options': {},
    }
    httpserver.expect_ordered_request('/api/chat', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    with io.BytesIO() as buffer:
        Image.new('RGB', (1, 1)).save(buffer, 'PNG')
        result = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [buffer.getvalue()]}])
    assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_async_client_generate(httpserver: HTTPServer):
    """Async non-streaming generate: defaults fill every payload field."""
    expected_body = {
        'model': 'dummy',
        'prompt': 'Why is the sky blue?',
        'system': '',
        'template': '',
        'context': [],
        'stream': False,
        'raw': False,
        'images': [],
        'format': '',
        'options': {},
    }
    httpserver.expect_ordered_request('/api/generate', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    result = await client.generate('dummy', 'Why is the sky blue?')
    assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_async_client_generate_stream(httpserver: HTTPServer):
    """Async streaming generate: the client returns an async generator (not consumed here)."""
    expected_body = {
        'model': 'dummy',
        'prompt': 'Why is the sky blue?',
        'system': '',
        'template': '',
        'context': [],
        'stream': True,
        'raw': False,
        'images': [],
        'format': '',
        'options': {},
    }
    httpserver.expect_ordered_request('/api/generate', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    result = await client.generate('dummy', 'Why is the sky blue?', stream=True)
    assert isinstance(result, types.AsyncGeneratorType)
@pytest.mark.asyncio
async def test_async_client_generate_images(httpserver: HTTPServer):
    """Async generate: an image passed as a file path is read and base64-encoded."""
    expected_body = {
        'model': 'dummy',
        'prompt': 'Why is the sky blue?',
        'system': '',
        'template': '',
        'context': [],
        'stream': False,
        'raw': False,
        # base64 of a 1x1 RGB PNG, matching the image written below
        'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
        'format': '',
        'options': {},
    }
    httpserver.expect_ordered_request('/api/generate', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as temp:
        Image.new('RGB', (1, 1)).save(temp, 'PNG')
        result = await client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
    assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_async_client_pull(httpserver: HTTPServer):
    """Async non-streaming pull: posts the expected payload and returns a dict."""
    expected_body = {
        'model': 'dummy',
        'insecure': False,
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/pull', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    result = await client.pull('dummy')
    assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_async_client_pull_stream(httpserver: HTTPServer):
    """Async streaming pull: the client returns an async generator (not consumed here)."""
    expected_body = {
        'model': 'dummy',
        'insecure': False,
        'stream': True,
    }
    httpserver.expect_ordered_request('/api/pull', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    result = await client.pull('dummy', stream=True)
    assert isinstance(result, types.AsyncGeneratorType)
@pytest.mark.asyncio
async def test_async_client_push(httpserver: HTTPServer):
    """Async non-streaming push: posts the expected payload and returns a dict."""
    expected_body = {
        'model': 'dummy',
        'insecure': False,
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/push', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    result = await client.push('dummy')
    assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_async_client_push_stream(httpserver: HTTPServer):
    """Async streaming push: the client returns an async generator (not consumed here)."""
    expected_body = {
        'model': 'dummy',
        'insecure': False,
        'stream': True,
    }
    httpserver.expect_ordered_request('/api/push', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    result = await client.push('dummy', stream=True)
    assert isinstance(result, types.AsyncGeneratorType)
@pytest.mark.asyncio
async def test_async_client_create_path(httpserver: HTTPServer):
    """Async create(path=...): an absolute FROM path is replaced by the blob's digest."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
    expected_body = {
        'model': 'dummy',
        'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/create', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as modelfile:
        with tempfile.NamedTemporaryFile() as blob:
            modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
            modelfile.flush()
            result = await client.create('dummy', path=modelfile.name)
    assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_async_client_create_path_relative(httpserver: HTTPServer):
    """Async create(path=...): a FROM path relative to the modelfile's directory resolves."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
    expected_body = {
        'model': 'dummy',
        'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/create', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as modelfile:
        # Place the blob next to the modelfile so a bare filename resolves against it.
        with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
            modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
            modelfile.flush()
            result = await client.create('dummy', path=modelfile.name)
    assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
    """Async create(path=...): a ``~/``-prefixed FROM path is expanded against $HOME."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
    expected_body = {
        'model': 'dummy',
        'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/create', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as modelfile:
        # The blob lives inside the fixture-provided fake home directory.
        with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
            modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
            modelfile.flush()
            result = await client.create('dummy', path=modelfile.name)
    assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_async_client_create_modelfile(httpserver: HTTPServer):
    """Async create(modelfile=...): an inline modelfile's FROM path becomes a digest."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
    expected_body = {
        'model': 'dummy',
        'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/create', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as blob:
        result = await client.create('dummy', modelfile=f'FROM {blob.name}')
    assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_async_client_create_from_library(httpserver: HTTPServer):
    """Async create(): a FROM naming a library model is passed through unchanged."""
    expected_body = {
        'model': 'dummy',
        'modelfile': 'FROM llama2\n',
        'stream': False,
    }
    httpserver.expect_ordered_request('/api/create', method='POST', json=expected_body).respond_with_json({})
    client = AsyncClient(httpserver.url_for('/'))
    result = await client.create('dummy', modelfile='FROM llama2')
    assert isinstance(result, dict)
@pytest.mark.asyncio
async def test_async_client_create_blob(httpserver: HTTPServer):
    """Async _create_blob: uploads via PUT when the HEAD probe reports 404."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
    client = AsyncClient(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as blob:
        digest = await client._create_blob(blob.name)
    # sha256 digest of empty content, since the temp file is never written to.
    assert digest == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
@pytest.mark.asyncio
async def test_async_client_create_blob_exists(httpserver: HTTPServer):
    """Async _create_blob: skips the upload (no PUT expected) when HEAD reports 200."""
    httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
    client = AsyncClient(httpserver.url_for('/'))
    with tempfile.NamedTemporaryFile() as blob:
        digest = await client._create_blob(blob.name)
    assert digest == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'