mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-13 21:57:16 +08:00
async client
This commit is contained in:
parent
21578e3c5e
commit
20db23d932
@ -1,4 +1,19 @@
|
||||
from ollama.client import Client
|
||||
from ._client import Client, AsyncClient
|
||||
|
||||
__all__ = [
|
||||
'Client',
|
||||
'AsyncClient',
|
||||
'generate',
|
||||
'chat',
|
||||
'pull',
|
||||
'push',
|
||||
'create',
|
||||
'delete',
|
||||
'list',
|
||||
'copy',
|
||||
'show',
|
||||
]
|
||||
|
||||
|
||||
_default_client = Client()
|
||||
|
||||
|
||||
337
ollama/_client.py
Normal file
337
ollama/_client.py
Normal file
@ -0,0 +1,337 @@
|
||||
import io
|
||||
import json
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from hashlib import sha256
|
||||
from base64 import b64encode
|
||||
|
||||
|
||||
class BaseClient:
|
||||
|
||||
def __init__(self, client, base_url='http://127.0.0.1:11434'):
|
||||
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
|
||||
|
||||
|
||||
class Client(BaseClient):
|
||||
|
||||
def __init__(self, base='http://localhost:11434'):
|
||||
super().__init__(httpx.Client, base)
|
||||
|
||||
def _request(self, method, url, **kwargs):
|
||||
response = self._client.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
def _request_json(self, method, url, **kwargs):
|
||||
return self._request(method, url, **kwargs).json()
|
||||
|
||||
def _stream(self, method, url, **kwargs):
|
||||
with self._client.stream(method, url, **kwargs) as r:
|
||||
for line in r.iter_lines():
|
||||
part = json.loads(line)
|
||||
if e := part.get('error'):
|
||||
raise Exception(e)
|
||||
yield part
|
||||
|
||||
def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn('POST', '/api/generate', json={
|
||||
'model': model,
|
||||
'prompt': prompt,
|
||||
'system': system,
|
||||
'template': template,
|
||||
'context': context or [],
|
||||
'stream': stream,
|
||||
'raw': raw,
|
||||
'images': [_encode_image(image) for image in images or []],
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
})
|
||||
|
||||
def chat(self, model='', messages=None, stream=False, format='', options=None):
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
for message in messages or []:
|
||||
if not isinstance(message, dict):
|
||||
raise TypeError('messages must be a list of strings')
|
||||
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
|
||||
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
|
||||
if not message.get('content'):
|
||||
raise Exception('messages must contain content')
|
||||
if images := message.get('images'):
|
||||
message['images'] = [_encode_image(image) for image in images]
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn('POST', '/api/chat', json={
|
||||
'model': model,
|
||||
'messages': messages,
|
||||
'stream': stream,
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
})
|
||||
|
||||
def pull(self, model, insecure=False, stream=False):
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn('POST', '/api/pull', json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
def push(self, model, insecure=False, stream=False):
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn('POST', '/api/push', json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
def create(self, model, path=None, modelfile=None, stream=False):
|
||||
if (path := _as_path(path)) and path.exists():
|
||||
modelfile = self._parse_modelfile(path.read_text(), base=path.parent)
|
||||
elif modelfile:
|
||||
modelfile = self._parse_modelfile(modelfile)
|
||||
else:
|
||||
raise Exception('must provide either path or modelfile')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn('POST', '/api/create', json={
|
||||
'model': model,
|
||||
'modelfile': modelfile,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
def _parse_modelfile(self, modelfile, base=None):
|
||||
base = Path.cwd() if base is None else base
|
||||
|
||||
out = io.StringIO()
|
||||
for line in io.StringIO(modelfile):
|
||||
command, _, args = line.partition(' ')
|
||||
if command.upper() in ['FROM', 'ADAPTER']:
|
||||
path = Path(args).expanduser()
|
||||
path = path if path.is_absolute() else base / path
|
||||
if path.exists():
|
||||
args = f'@{self._create_blob(path)}'
|
||||
|
||||
print(command, args, file=out)
|
||||
return out.getvalue()
|
||||
|
||||
def _create_blob(self, path):
|
||||
sha256sum = sha256()
|
||||
with open(path, 'rb') as r:
|
||||
while True:
|
||||
chunk = r.read(32*1024)
|
||||
if not chunk:
|
||||
break
|
||||
sha256sum.update(chunk)
|
||||
|
||||
digest = f'sha256:{sha256sum.hexdigest()}'
|
||||
|
||||
try:
|
||||
self._request('HEAD', f'/api/blobs/{digest}')
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code != 404:
|
||||
raise
|
||||
|
||||
with open(path, 'rb') as r:
|
||||
self._request('PUT', f'/api/blobs/{digest}', content=r)
|
||||
|
||||
return digest
|
||||
|
||||
def delete(self, model):
|
||||
response = self._request_json('DELETE', '/api/delete', json={'model': model})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
def list(self):
|
||||
return self._request_json('GET', '/api/tags').get('models', [])
|
||||
|
||||
def copy(self, source, target):
|
||||
response = self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
def show(self, model):
|
||||
return self._request_json('GET', '/api/show', json={'model': model})
|
||||
|
||||
|
||||
class AsyncClient(BaseClient):
|
||||
|
||||
def __init__(self, base='http://localhost:11434'):
|
||||
super().__init__(httpx.AsyncClient, base)
|
||||
|
||||
async def _request(self, method, url, **kwargs):
|
||||
response = await self._client.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
async def _request_json(self, method, url, **kwargs):
|
||||
response = await self._request(method, url, **kwargs)
|
||||
return response.json()
|
||||
|
||||
async def _stream(self, method, url, **kwargs):
|
||||
async def inner():
|
||||
async with self._client.stream(method, url, **kwargs) as r:
|
||||
async for line in r.aiter_lines():
|
||||
part = json.loads(line)
|
||||
if e := part.get('error'):
|
||||
raise Exception(e)
|
||||
yield part
|
||||
return inner()
|
||||
|
||||
async def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn('POST', '/api/generate', json={
|
||||
'model': model,
|
||||
'prompt': prompt,
|
||||
'system': system,
|
||||
'template': template,
|
||||
'context': context or [],
|
||||
'stream': stream,
|
||||
'raw': raw,
|
||||
'images': [_encode_image(image) for image in images or []],
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
})
|
||||
|
||||
async def chat(self, model='', messages=None, stream=False, format='', options=None):
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
for message in messages or []:
|
||||
if not isinstance(message, dict):
|
||||
raise TypeError('messages must be a list of strings')
|
||||
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
|
||||
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
|
||||
if not message.get('content'):
|
||||
raise Exception('messages must contain content')
|
||||
if images := message.get('images'):
|
||||
message['images'] = [_encode_image(image) for image in images]
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn('POST', '/api/chat', json={
|
||||
'model': model,
|
||||
'messages': messages,
|
||||
'stream': stream,
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
})
|
||||
|
||||
async def pull(self, model, insecure=False, stream=False):
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn('POST', '/api/pull', json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
async def push(self, model, insecure=False, stream=False):
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn('POST', '/api/push', json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
async def create(self, model, path=None, modelfile=None, stream=False):
|
||||
if (path := _as_path(path)) and path.exists():
|
||||
modelfile = await self._parse_modelfile(path.read_text(), base=path.parent)
|
||||
elif modelfile:
|
||||
modelfile = await self._parse_modelfile(modelfile)
|
||||
else:
|
||||
raise Exception('must provide either path or modelfile')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn('POST', '/api/create', json={
|
||||
'model': model,
|
||||
'modelfile': modelfile,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
async def _parse_modelfile(self, modelfile, base=None):
|
||||
base = Path.cwd() if base is None else base
|
||||
|
||||
out = io.StringIO()
|
||||
for line in io.StringIO(modelfile):
|
||||
command, _, args = line.partition(' ')
|
||||
if command.upper() in ['FROM', 'ADAPTER']:
|
||||
path = Path(args).expanduser()
|
||||
path = path if path.is_absolute() else base / path
|
||||
if path.exists():
|
||||
args = f'@{await self._create_blob(path)}'
|
||||
|
||||
print(command, args, file=out)
|
||||
return out.getvalue()
|
||||
|
||||
async def _create_blob(self, path):
|
||||
sha256sum = sha256()
|
||||
with open(path, 'rb') as r:
|
||||
while True:
|
||||
chunk = r.read(32*1024)
|
||||
if not chunk:
|
||||
break
|
||||
sha256sum.update(chunk)
|
||||
|
||||
digest = f'sha256:{sha256sum.hexdigest()}'
|
||||
|
||||
try:
|
||||
await self._request('HEAD', f'/api/blobs/{digest}')
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code != 404:
|
||||
raise
|
||||
|
||||
async def upload_bytes():
|
||||
with open(path, 'rb') as r:
|
||||
while True:
|
||||
chunk = r.read(32*1024)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
await self._request('PUT', f'/api/blobs/{digest}', content=upload_bytes())
|
||||
|
||||
return digest
|
||||
|
||||
async def delete(self, model):
|
||||
response = await self._request_json('DELETE', '/api/delete', json={'model': model})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
async def list(self):
|
||||
response = await self._request_json('GET', '/api/tags')
|
||||
return response.get('models', [])
|
||||
|
||||
async def copy(self, source, target):
|
||||
response = await self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
async def show(self, model):
|
||||
return await self._request_json('GET', '/api/show', json={'model': model})
|
||||
|
||||
|
||||
def _encode_image(image):
|
||||
if p := _as_path(image):
|
||||
b64 = b64encode(p.read_bytes())
|
||||
elif b := _as_bytesio(image):
|
||||
b64 = b64encode(b.read())
|
||||
else:
|
||||
raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
|
||||
|
||||
return b64.decode('utf-8')
|
||||
|
||||
|
||||
def _as_path(s):
|
||||
if isinstance(s, str) or isinstance(s, Path):
|
||||
return Path(s)
|
||||
return None
|
||||
|
||||
def _as_bytesio(s):
|
||||
if isinstance(s, io.BytesIO):
|
||||
return s
|
||||
elif isinstance(s, bytes):
|
||||
return io.BytesIO(s)
|
||||
return None
|
||||
188
ollama/client.py
188
ollama/client.py
@ -1,188 +0,0 @@
|
||||
import io
|
||||
import json
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from hashlib import sha256
|
||||
from base64 import b64encode
|
||||
|
||||
|
||||
class BaseClient:
|
||||
|
||||
def __init__(self, client, base_url='http://127.0.0.1:11434'):
|
||||
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
|
||||
|
||||
|
||||
class Client(BaseClient):
|
||||
|
||||
def __init__(self, base='http://localhost:11434'):
|
||||
super().__init__(httpx.Client, base)
|
||||
|
||||
def _request(self, method, url, **kwargs):
|
||||
response = self._client.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
def _request_json(self, method, url, **kwargs):
|
||||
return self._request(method, url, **kwargs).json()
|
||||
|
||||
def stream(self, method, url, **kwargs):
|
||||
with self._client.stream(method, url, **kwargs) as r:
|
||||
for line in r.iter_lines():
|
||||
part = json.loads(line)
|
||||
if e := part.get('error'):
|
||||
raise Exception(e)
|
||||
yield part
|
||||
|
||||
def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
fn = self.stream if stream else self._request_json
|
||||
return fn('POST', '/api/generate', json={
|
||||
'model': model,
|
||||
'prompt': prompt,
|
||||
'system': system,
|
||||
'template': template,
|
||||
'context': context or [],
|
||||
'stream': stream,
|
||||
'raw': raw,
|
||||
'images': [_encode_image(image) for image in images or []],
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
})
|
||||
|
||||
def chat(self, model='', messages=None, stream=False, format='', options=None):
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
for message in messages or []:
|
||||
if not isinstance(message, dict):
|
||||
raise TypeError('messages must be a list of strings')
|
||||
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
|
||||
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
|
||||
if not message.get('content'):
|
||||
raise Exception('messages must contain content')
|
||||
if images := message.get('images'):
|
||||
message['images'] = [_encode_image(image) for image in images]
|
||||
|
||||
fn = self.stream if stream else self._request_json
|
||||
return fn('POST', '/api/chat', json={
|
||||
'model': model,
|
||||
'messages': messages,
|
||||
'stream': stream,
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
})
|
||||
|
||||
def pull(self, model, insecure=False, stream=False):
|
||||
fn = self.stream if stream else self._request_json
|
||||
return fn('POST', '/api/pull', json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
def push(self, model, insecure=False, stream=False):
|
||||
fn = self.stream if stream else self._request_json
|
||||
return fn('POST', '/api/push', json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
def create(self, model, path=None, modelfile=None, stream=False):
|
||||
if (path := _as_path(path)) and path.exists():
|
||||
modelfile = _parse_modelfile(path.read_text(), self.create_blob, base=path.parent)
|
||||
elif modelfile:
|
||||
modelfile = _parse_modelfile(modelfile, self.create_blob)
|
||||
else:
|
||||
raise Exception('must provide either path or modelfile')
|
||||
|
||||
fn = self.stream if stream else self._request_json
|
||||
return fn('POST', '/api/create', json={
|
||||
'model': model,
|
||||
'modelfile': modelfile,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
|
||||
def create_blob(self, path):
|
||||
sha256sum = sha256()
|
||||
with open(path, 'rb') as r:
|
||||
while True:
|
||||
chunk = r.read(32*1024)
|
||||
if not chunk:
|
||||
break
|
||||
sha256sum.update(chunk)
|
||||
|
||||
digest = f'sha256:{sha256sum.hexdigest()}'
|
||||
|
||||
try:
|
||||
self._request('HEAD', f'/api/blobs/{digest}')
|
||||
except httpx.HTTPError:
|
||||
with open(path, 'rb') as r:
|
||||
self._request('PUT', f'/api/blobs/{digest}', content=r)
|
||||
|
||||
return digest
|
||||
|
||||
def delete(self, model):
|
||||
response = self._request_json('DELETE', '/api/delete', json={'model': model})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
def list(self):
|
||||
return self._request_json('GET', '/api/tags').get('models', [])
|
||||
|
||||
def copy(self, source, target):
|
||||
response = self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
def show(self, model):
|
||||
return self._request_json('GET', '/api/show', json={'model': model}).json()
|
||||
|
||||
|
||||
def _encode_image(image):
|
||||
'''
|
||||
_encode_images takes a list of images and returns a generator of base64 encoded images.
|
||||
if the image is a string, it is assumed to be a path to a file.
|
||||
if the image is a Path object, it is assumed to be a path to a file.
|
||||
if the image is a bytes object, it is assumed to be the raw bytes of an image.
|
||||
if the image is a file-like object, it is assumed to be a container to the raw bytes of an image.
|
||||
'''
|
||||
|
||||
if p := _as_path(image):
|
||||
b64 = b64encode(p.read_bytes())
|
||||
elif b := _as_bytesio(image):
|
||||
b64 = b64encode(b.read())
|
||||
else:
|
||||
raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
|
||||
|
||||
return b64.decode('utf-8')
|
||||
|
||||
|
||||
def _parse_modelfile(modelfile, cb, base=None):
|
||||
base = Path.cwd() if base is None else base
|
||||
|
||||
out = io.StringIO()
|
||||
for line in io.StringIO(modelfile):
|
||||
command, _, args = line.partition(' ')
|
||||
if command.upper() in ['FROM', 'ADAPTER']:
|
||||
path = Path(args).expanduser()
|
||||
path = path if path.is_absolute() else base / path
|
||||
if path.exists():
|
||||
args = f'@{cb(path)}'
|
||||
|
||||
print(command, args, file=out)
|
||||
return out.getvalue()
|
||||
|
||||
|
||||
def _as_path(s):
|
||||
if isinstance(s, str) or isinstance(s, Path):
|
||||
return Path(s)
|
||||
return None
|
||||
|
||||
def _as_bytesio(s):
|
||||
if isinstance(s, io.BytesIO):
|
||||
return s
|
||||
elif isinstance(s, bytes):
|
||||
return io.BytesIO(s)
|
||||
return None
|
||||
@ -1,292 +0,0 @@
|
||||
import pytest
|
||||
import os
|
||||
import io
|
||||
import types
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from ollama.client import Client
|
||||
from pytest_httpserver import HTTPServer, URIPattern
|
||||
from werkzeug.wrappers import Response
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class PrefixPattern(URIPattern):
|
||||
def __init__(self, prefix: str):
|
||||
self.prefix = prefix
|
||||
|
||||
def match(self, uri):
|
||||
return uri.startswith(self.prefix)
|
||||
|
||||
|
||||
def test_client_chat(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_chat_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': True,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_chat_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
},
|
||||
],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with io.BytesIO() as b:
|
||||
Image.new('RGB', (1, 1)).save(b, 'PNG')
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_generate(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_generate_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': True,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?', stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_generate_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
Image.new('RGB', (1, 1)).save(temp, 'PNG')
|
||||
response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_pull(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/pull', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.pull('dummy')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_pull_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/pull', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.pull('dummy', stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_push(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/push', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.push('dummy')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_push_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/push', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.push('dummy', stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_create_path(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_path_relative(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
|
||||
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def userhomedir():
|
||||
with tempfile.TemporaryDirectory() as temp:
|
||||
home = os.getenv('HOME', '')
|
||||
os.environ['HOME'] = temp
|
||||
yield Path(temp)
|
||||
os.environ['HOME'] = home
|
||||
|
||||
|
||||
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
|
||||
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_modelfile(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client.create('dummy', modelfile=f'FROM {blob.name}')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_from_library(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM llama2\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
response = client.create('dummy', modelfile='FROM llama2')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_blob(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client.create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
|
||||
|
||||
def test_client_create_blob_exists(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client.create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
20
poetry.lock
generated
20
poetry.lock
generated
@ -387,6 +387,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "0.23.2"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"},
|
||||
{file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=7.0.0"
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
||||
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "4.1.0"
|
||||
@ -498,4 +516,4 @@ watchdog = ["watchdog (>=2.3)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.8"
|
||||
content-hash = "b9f64e1a5795a417d2dbff7286360f8d3f8f10fdfa9580411940d144c2561e92"
|
||||
content-hash = "9416a897c95d3c80cf1bfd3cc61cd19f0143c9bd0bc7c219fcb31ee27c497c9d"
|
||||
|
||||
@ -11,6 +11,7 @@ httpx = "^0.25.2"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^7.4.3"
|
||||
pytest-asyncio = "^0.23.2"
|
||||
pytest-cov = "^4.1.0"
|
||||
pytest-httpserver = "^1.0.8"
|
||||
pillow = "^10.1.0"
|
||||
|
||||
636
tests/test_client.py
Normal file
636
tests/test_client.py
Normal file
@ -0,0 +1,636 @@
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import types
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from pytest_httpserver import HTTPServer, URIPattern
|
||||
from werkzeug.wrappers import Request, Response
|
||||
from PIL import Image
|
||||
|
||||
from ollama._client import Client, AsyncClient
|
||||
|
||||
|
||||
class PrefixPattern(URIPattern):
|
||||
def __init__(self, prefix: str):
|
||||
self.prefix = prefix
|
||||
|
||||
def match(self, uri):
|
||||
return uri.startswith(self.prefix)
|
||||
|
||||
|
||||
def test_client_chat(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': "I don't know.",
|
||||
},
|
||||
})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == "I don't know."
|
||||
|
||||
|
||||
def test_client_chat_stream(httpserver: HTTPServer):
|
||||
def stream_handler(_: Request):
|
||||
def generate():
|
||||
for message in ['I ', "don't ", 'know.']:
|
||||
yield json.dumps({
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': message,
|
||||
},
|
||||
}) + '\n'
|
||||
return Response(generate())
|
||||
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': True,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_handler(stream_handler)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
|
||||
for part in response:
|
||||
assert part['message']['role'] in 'assistant'
|
||||
assert part['message']['content'] in ['I ', "don't ", 'know.']
|
||||
|
||||
|
||||
def test_client_chat_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
},
|
||||
],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': "I don't know.",
|
||||
},
|
||||
})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with io.BytesIO() as b:
|
||||
Image.new('RGB', (1, 1)).save(b, 'PNG')
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == "I don't know."
|
||||
|
||||
|
||||
def test_client_generate(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({
|
||||
'model': 'dummy',
|
||||
'response': 'Because it is.',
|
||||
})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?')
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == 'Because it is.'
|
||||
|
||||
|
||||
def test_client_generate_stream(httpserver: HTTPServer):
|
||||
def stream_handler(_: Request):
|
||||
def generate():
|
||||
for message in ['Because ', 'it ', 'is.']:
|
||||
yield json.dumps({
|
||||
'model': 'dummy',
|
||||
'response': message,
|
||||
}) + '\n'
|
||||
return Response(generate())
|
||||
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': True,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_handler(stream_handler)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?', stream=True)
|
||||
for part in response:
|
||||
assert part['model'] == 'dummy'
|
||||
assert part['response'] in ['Because ', 'it ', 'is.']
|
||||
|
||||
|
||||
def test_client_generate_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({
|
||||
'model': 'dummy',
|
||||
'response': 'Because it is.',
|
||||
})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
Image.new('RGB', (1, 1)).save(temp, 'PNG')
|
||||
response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == 'Because it is.'
|
||||
|
||||
|
||||
def test_client_pull(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/pull', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
}).respond_with_json({
|
||||
'status': 'success',
|
||||
})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.pull('dummy')
|
||||
assert response['status'] == 'success'
|
||||
|
||||
|
||||
def test_client_pull_stream(httpserver: HTTPServer):
|
||||
def stream_handler(_: Request):
|
||||
def generate():
|
||||
yield json.dumps({'status': 'pulling manifest'}) + '\n'
|
||||
yield json.dumps({'status': 'verifying sha256 digest'}) + '\n'
|
||||
yield json.dumps({'status': 'writing manifest'}) + '\n'
|
||||
yield json.dumps({'status': 'removing any unused layers'}) + '\n'
|
||||
yield json.dumps({'status': 'success'}) + '\n'
|
||||
return Response(generate())
|
||||
|
||||
httpserver.expect_ordered_request('/api/pull', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.pull('dummy', stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_push(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/push', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.push('dummy')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_push_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/push', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.push('dummy', stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_create_path(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_path_relative(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
|
||||
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def userhomedir():
|
||||
with tempfile.TemporaryDirectory() as temp:
|
||||
home = os.getenv('HOME', '')
|
||||
os.environ['HOME'] = temp
|
||||
yield Path(temp)
|
||||
os.environ['HOME'] = home
|
||||
|
||||
|
||||
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
|
||||
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_modelfile(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client.create('dummy', modelfile=f'FROM {blob.name}')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_from_library(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM llama2\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
response = client.create('dummy', modelfile='FROM llama2')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_blob(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client._create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
|
||||
|
||||
def test_client_create_blob_exists(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client._create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_chat(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_chat_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': True,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
|
||||
assert isinstance(response, types.AsyncGeneratorType)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_chat_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
},
|
||||
],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with io.BytesIO() as b:
|
||||
Image.new('RGB', (1, 1)).save(b, 'PNG')
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_generate(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.generate('dummy', 'Why is the sky blue?')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_generate_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': True,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.generate('dummy', 'Why is the sky blue?', stream=True)
|
||||
assert isinstance(response, types.AsyncGeneratorType)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_generate_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
Image.new('RGB', (1, 1)).save(temp, 'PNG')
|
||||
response = await client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_pull(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/pull', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.pull('dummy')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_pull_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/pull', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.pull('dummy', stream=True)
|
||||
assert isinstance(response, types.AsyncGeneratorType)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_push(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/push', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.push('dummy')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_push_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/push', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.push('dummy', stream=True)
|
||||
assert isinstance(response, types.AsyncGeneratorType)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_path(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = await client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_path_relative(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
|
||||
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = await client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
|
||||
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = await client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_modelfile(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = await client.create('dummy', modelfile=f'FROM {blob.name}')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_from_library(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM llama2\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
response = await client.create('dummy', modelfile='FROM llama2')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_blob(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = await client._create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_blob_exists(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = await client._create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
Loading…
Reference in New Issue
Block a user