This commit is contained in:
Michael Yang 2023-12-21 14:21:02 -08:00
parent 9d93f70806
commit 47c934c74b
5 changed files with 576 additions and 392 deletions

View File

@ -3,9 +3,9 @@ from ollama import generate
prefix = '''def remove_non_ascii(s: str) -> str: prefix = '''def remove_non_ascii(s: str) -> str:
""" ''' """ '''
suffix = ''' suffix = """
return result return result
''' """
response = generate( response = generate(

View File

@ -9,6 +9,7 @@ from base64 import b64encode
from typing import Any, AnyStr, Union, Optional, List, Mapping from typing import Any, AnyStr, Union, Optional, List, Mapping
import sys import sys
if sys.version_info < (3, 9): if sys.version_info < (3, 9):
from typing import Iterator, AsyncIterator from typing import Iterator, AsyncIterator
else: else:
@ -18,13 +19,11 @@ from ollama._types import Message, Options
class BaseClient: class BaseClient:
def __init__(self, client, base_url='http://127.0.0.1:11434') -> None: def __init__(self, client, base_url='http://127.0.0.1:11434') -> None:
self._client = client(base_url=base_url, follow_redirects=True, timeout=None) self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
class Client(BaseClient): class Client(BaseClient):
def __init__(self, base='http://localhost:11434') -> None: def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.Client, base) super().__init__(httpx.Client, base)
@ -45,43 +44,47 @@ class Client(BaseClient):
yield part yield part
def generate( def generate(
self, self,
model: str = '', model: str = '',
prompt: str = '', prompt: str = '',
system: str = '', system: str = '',
template: str = '', template: str = '',
context: Optional[List[int]] = None, context: Optional[List[int]] = None,
stream: bool = False, stream: bool = False,
raw: bool = False, raw: bool = False,
format: str = '', format: str = '',
images: Optional[List[AnyStr]] = None, images: Optional[List[AnyStr]] = None,
options: Optional[Options] = None, options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if not model: if not model:
raise Exception('must provide a model') raise Exception('must provide a model')
fn = self._stream if stream else self._request_json fn = self._stream if stream else self._request_json
return fn('POST', '/api/generate', json={ return fn(
'model': model, 'POST',
'prompt': prompt, '/api/generate',
'system': system, json={
'template': template, 'model': model,
'context': context or [], 'prompt': prompt,
'stream': stream, 'system': system,
'raw': raw, 'template': template,
'images': [_encode_image(image) for image in images or []], 'context': context or [],
'format': format, 'stream': stream,
'options': options or {}, 'raw': raw,
}) 'images': [_encode_image(image) for image in images or []],
'format': format,
'options': options or {},
},
)
def chat( def chat(
self, self,
model: str = '', model: str = '',
messages: Optional[List[Message]] = None, messages: Optional[List[Message]] = None,
stream: bool = False, stream: bool = False,
format: str = '', format: str = '',
options: Optional[Options] = None, options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if not model: if not model:
raise Exception('must provide a model') raise Exception('must provide a model')
@ -96,47 +99,59 @@ class Client(BaseClient):
message['images'] = [_encode_image(image) for image in images] message['images'] = [_encode_image(image) for image in images]
fn = self._stream if stream else self._request_json fn = self._stream if stream else self._request_json
return fn('POST', '/api/chat', json={ return fn(
'model': model, 'POST',
'messages': messages, '/api/chat',
'stream': stream, json={
'format': format, 'model': model,
'options': options or {}, 'messages': messages,
}) 'stream': stream,
'format': format,
'options': options or {},
},
)
def pull( def pull(
self, self,
model: str, model: str,
insecure: bool = False, insecure: bool = False,
stream: bool = False, stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json fn = self._stream if stream else self._request_json
return fn('POST', '/api/pull', json={ return fn(
'model': model, 'POST',
'insecure': insecure, '/api/pull',
'stream': stream, json={
}) 'model': model,
'insecure': insecure,
'stream': stream,
},
)
def push( def push(
self, self,
model: str, model: str,
insecure: bool = False, insecure: bool = False,
stream: bool = False, stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json fn = self._stream if stream else self._request_json
return fn('POST', '/api/push', json={ return fn(
'model': model, 'POST',
'insecure': insecure, '/api/push',
'stream': stream, json={
}) 'model': model,
'insecure': insecure,
'stream': stream,
},
)
def create( def create(
self, self,
model: str, model: str,
path: Optional[Union[str, PathLike]] = None, path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None, modelfile: Optional[str] = None,
stream: bool = False, stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if (realpath := _as_path(path)) and realpath.exists(): if (realpath := _as_path(path)) and realpath.exists():
modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent) modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile: elif modelfile:
@ -145,11 +160,15 @@ class Client(BaseClient):
raise Exception('must provide either path or modelfile') raise Exception('must provide either path or modelfile')
fn = self._stream if stream else self._request_json fn = self._stream if stream else self._request_json
return fn('POST', '/api/create', json={ return fn(
'model': model, 'POST',
'modelfile': modelfile, '/api/create',
'stream': stream, json={
}) 'model': model,
'modelfile': modelfile,
'stream': stream,
},
)
def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base base = Path.cwd() if base is None else base
@ -170,7 +189,7 @@ class Client(BaseClient):
sha256sum = sha256() sha256sum = sha256()
with open(path, 'rb') as r: with open(path, 'rb') as r:
while True: while True:
chunk = r.read(32*1024) chunk = r.read(32 * 1024)
if not chunk: if not chunk:
break break
sha256sum.update(chunk) sha256sum.update(chunk)
@ -204,7 +223,6 @@ class Client(BaseClient):
class AsyncClient(BaseClient): class AsyncClient(BaseClient):
def __init__(self, base='http://localhost:11434') -> None: def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.AsyncClient, base) super().__init__(httpx.AsyncClient, base)
@ -225,46 +243,51 @@ class AsyncClient(BaseClient):
if e := part.get('error'): if e := part.get('error'):
raise Exception(e) raise Exception(e)
yield part yield part
return inner() return inner()
async def generate( async def generate(
self, self,
model: str = '', model: str = '',
prompt: str = '', prompt: str = '',
system: str = '', system: str = '',
template: str = '', template: str = '',
context: Optional[List[int]] = None, context: Optional[List[int]] = None,
stream: bool = False, stream: bool = False,
raw: bool = False, raw: bool = False,
format: str = '', format: str = '',
images: Optional[List[AnyStr]] = None, images: Optional[List[AnyStr]] = None,
options: Optional[Options] = None, options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if not model: if not model:
raise Exception('must provide a model') raise Exception('must provide a model')
fn = self._stream if stream else self._request_json fn = self._stream if stream else self._request_json
return await fn('POST', '/api/generate', json={ return await fn(
'model': model, 'POST',
'prompt': prompt, '/api/generate',
'system': system, json={
'template': template, 'model': model,
'context': context or [], 'prompt': prompt,
'stream': stream, 'system': system,
'raw': raw, 'template': template,
'images': [_encode_image(image) for image in images or []], 'context': context or [],
'format': format, 'stream': stream,
'options': options or {}, 'raw': raw,
}) 'images': [_encode_image(image) for image in images or []],
'format': format,
'options': options or {},
},
)
async def chat( async def chat(
self, self,
model: str = '', model: str = '',
messages: Optional[List[Message]] = None, messages: Optional[List[Message]] = None,
stream: bool = False, stream: bool = False,
format: str = '', format: str = '',
options: Optional[Options] = None, options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if not model: if not model:
raise Exception('must provide a model') raise Exception('must provide a model')
@ -279,47 +302,59 @@ class AsyncClient(BaseClient):
message['images'] = [_encode_image(image) for image in images] message['images'] = [_encode_image(image) for image in images]
fn = self._stream if stream else self._request_json fn = self._stream if stream else self._request_json
return await fn('POST', '/api/chat', json={ return await fn(
'model': model, 'POST',
'messages': messages, '/api/chat',
'stream': stream, json={
'format': format, 'model': model,
'options': options or {}, 'messages': messages,
}) 'stream': stream,
'format': format,
'options': options or {},
},
)
async def pull( async def pull(
self, self,
model: str, model: str,
insecure: bool = False, insecure: bool = False,
stream: bool = False, stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json fn = self._stream if stream else self._request_json
return await fn('POST', '/api/pull', json={ return await fn(
'model': model, 'POST',
'insecure': insecure, '/api/pull',
'stream': stream, json={
}) 'model': model,
'insecure': insecure,
'stream': stream,
},
)
async def push( async def push(
self, self,
model: str, model: str,
insecure: bool = False, insecure: bool = False,
stream: bool = False, stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json fn = self._stream if stream else self._request_json
return await fn('POST', '/api/push', json={ return await fn(
'model': model, 'POST',
'insecure': insecure, '/api/push',
'stream': stream, json={
}) 'model': model,
'insecure': insecure,
'stream': stream,
},
)
async def create( async def create(
self, self,
model: str, model: str,
path: Optional[Union[str, PathLike]] = None, path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None, modelfile: Optional[str] = None,
stream: bool = False, stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if (realpath := _as_path(path)) and realpath.exists(): if (realpath := _as_path(path)) and realpath.exists():
modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent) modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile: elif modelfile:
@ -328,11 +363,15 @@ class AsyncClient(BaseClient):
raise Exception('must provide either path or modelfile') raise Exception('must provide either path or modelfile')
fn = self._stream if stream else self._request_json fn = self._stream if stream else self._request_json
return await fn('POST', '/api/create', json={ return await fn(
'model': model, 'POST',
'modelfile': modelfile, '/api/create',
'stream': stream, json={
}) 'model': model,
'modelfile': modelfile,
'stream': stream,
},
)
async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base base = Path.cwd() if base is None else base
@ -353,7 +392,7 @@ class AsyncClient(BaseClient):
sha256sum = sha256() sha256sum = sha256()
with open(path, 'rb') as r: with open(path, 'rb') as r:
while True: while True:
chunk = r.read(32*1024) chunk = r.read(32 * 1024)
if not chunk: if not chunk:
break break
sha256sum.update(chunk) sha256sum.update(chunk)
@ -369,7 +408,7 @@ class AsyncClient(BaseClient):
async def upload_bytes(): async def upload_bytes():
with open(path, 'rb') as r: with open(path, 'rb') as r:
while True: while True:
chunk = r.read(32*1024) chunk = r.read(32 * 1024)
if not chunk: if not chunk:
break break
yield chunk yield chunk

View File

@ -1,6 +1,7 @@
from typing import Any, TypedDict, List from typing import Any, TypedDict, List
import sys import sys
if sys.version_info < (3, 11): if sys.version_info < (3, 11):
from typing_extensions import NotRequired from typing_extensions import NotRequired
else: else:

View File

@ -25,6 +25,7 @@ requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.ruff] [tool.ruff]
line-length = 999
indent-width = 2 indent-width = 2
[tool.ruff.format] [tool.ruff.format]

View File

@ -21,19 +21,25 @@ class PrefixPattern(URIPattern):
def test_client_chat(httpserver: HTTPServer): def test_client_chat(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/chat', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/chat',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], method='POST',
'stream': False, json={
'format': '', 'model': 'dummy',
'options': {}, 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
}).respond_with_json({ 'stream': False,
'model': 'dummy', 'format': '',
'message': { 'options': {},
'role': 'assistant', },
'content': "I don't know.", ).respond_with_json(
}, {
}) 'model': 'dummy',
'message': {
'role': 'assistant',
'content': "I don't know.",
},
}
)
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}]) response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
@ -46,22 +52,32 @@ def test_client_chat_stream(httpserver: HTTPServer):
def stream_handler(_: Request): def stream_handler(_: Request):
def generate(): def generate():
for message in ['I ', "don't ", 'know.']: for message in ['I ', "don't ", 'know.']:
yield json.dumps({ yield (
'model': 'dummy', json.dumps(
'message': { {
'role': 'assistant', 'model': 'dummy',
'content': message, 'message': {
}, 'role': 'assistant',
}) + '\n' 'content': message,
},
}
)
+ '\n'
)
return Response(generate()) return Response(generate())
httpserver.expect_ordered_request('/api/chat', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/chat',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], method='POST',
'stream': True, json={
'format': '', 'model': 'dummy',
'options': {}, 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
}).respond_with_handler(stream_handler) 'stream': True,
'format': '',
'options': {},
},
).respond_with_handler(stream_handler)
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True) response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
@ -71,25 +87,31 @@ def test_client_chat_stream(httpserver: HTTPServer):
def test_client_chat_images(httpserver: HTTPServer): def test_client_chat_images(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/chat', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/chat',
'messages': [ method='POST',
{ json={
'role': 'user', 'model': 'dummy',
'content': 'Why is the sky blue?', 'messages': [
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], {
'role': 'user',
'content': 'Why is the sky blue?',
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'stream': False,
'format': '',
'options': {},
},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': "I don't know.",
}, },
], }
'stream': False, )
'format': '',
'options': {},
}).respond_with_json({
'model': 'dummy',
'message': {
'role': 'assistant',
'content': "I don't know.",
},
})
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
@ -102,21 +124,27 @@ def test_client_chat_images(httpserver: HTTPServer):
def test_client_generate(httpserver: HTTPServer): def test_client_generate(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/generate', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/generate',
'prompt': 'Why is the sky blue?', method='POST',
'system': '', json={
'template': '', 'model': 'dummy',
'context': [], 'prompt': 'Why is the sky blue?',
'stream': False, 'system': '',
'raw': False, 'template': '',
'images': [], 'context': [],
'format': '', 'stream': False,
'options': {}, 'raw': False,
}).respond_with_json({ 'images': [],
'model': 'dummy', 'format': '',
'response': 'Because it is.', 'options': {},
}) },
).respond_with_json(
{
'model': 'dummy',
'response': 'Because it is.',
}
)
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
response = client.generate('dummy', 'Why is the sky blue?') response = client.generate('dummy', 'Why is the sky blue?')
@ -128,24 +156,34 @@ def test_client_generate_stream(httpserver: HTTPServer):
def stream_handler(_: Request): def stream_handler(_: Request):
def generate(): def generate():
for message in ['Because ', 'it ', 'is.']: for message in ['Because ', 'it ', 'is.']:
yield json.dumps({ yield (
'model': 'dummy', json.dumps(
'response': message, {
}) + '\n' 'model': 'dummy',
'response': message,
}
)
+ '\n'
)
return Response(generate()) return Response(generate())
httpserver.expect_ordered_request('/api/generate', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/generate',
'prompt': 'Why is the sky blue?', method='POST',
'system': '', json={
'template': '', 'model': 'dummy',
'context': [], 'prompt': 'Why is the sky blue?',
'stream': True, 'system': '',
'raw': False, 'template': '',
'images': [], 'context': [],
'format': '', 'stream': True,
'options': {}, 'raw': False,
}).respond_with_handler(stream_handler) 'images': [],
'format': '',
'options': {},
},
).respond_with_handler(stream_handler)
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
response = client.generate('dummy', 'Why is the sky blue?', stream=True) response = client.generate('dummy', 'Why is the sky blue?', stream=True)
@ -155,21 +193,27 @@ def test_client_generate_stream(httpserver: HTTPServer):
def test_client_generate_images(httpserver: HTTPServer): def test_client_generate_images(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/generate', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/generate',
'prompt': 'Why is the sky blue?', method='POST',
'system': '', json={
'template': '', 'model': 'dummy',
'context': [], 'prompt': 'Why is the sky blue?',
'stream': False, 'system': '',
'raw': False, 'template': '',
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], 'context': [],
'format': '', 'stream': False,
'options': {}, 'raw': False,
}).respond_with_json({ 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'model': 'dummy', 'format': '',
'response': 'Because it is.', 'options': {},
}) },
).respond_with_json(
{
'model': 'dummy',
'response': 'Because it is.',
}
)
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
@ -181,13 +225,19 @@ def test_client_generate_images(httpserver: HTTPServer):
def test_client_pull(httpserver: HTTPServer): def test_client_pull(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/pull', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/pull',
'insecure': False, method='POST',
'stream': False, json={
}).respond_with_json({ 'model': 'dummy',
'status': 'success', 'insecure': False,
}) 'stream': False,
},
).respond_with_json(
{
'status': 'success',
}
)
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
response = client.pull('dummy') response = client.pull('dummy')
@ -202,13 +252,18 @@ def test_client_pull_stream(httpserver: HTTPServer):
yield json.dumps({'status': 'writing manifest'}) + '\n' yield json.dumps({'status': 'writing manifest'}) + '\n'
yield json.dumps({'status': 'removing any unused layers'}) + '\n' yield json.dumps({'status': 'removing any unused layers'}) + '\n'
yield json.dumps({'status': 'success'}) + '\n' yield json.dumps({'status': 'success'}) + '\n'
return Response(generate()) return Response(generate())
httpserver.expect_ordered_request('/api/pull', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/pull',
'insecure': False, method='POST',
'stream': True, json={
}).respond_with_json({}) 'model': 'dummy',
'insecure': False,
'stream': True,
},
).respond_with_json({})
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
response = client.pull('dummy', stream=True) response = client.pull('dummy', stream=True)
@ -216,11 +271,15 @@ def test_client_pull_stream(httpserver: HTTPServer):
def test_client_push(httpserver: HTTPServer): def test_client_push(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/push', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/push',
'insecure': False, method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'insecure': False,
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
response = client.push('dummy') response = client.push('dummy')
@ -228,11 +287,15 @@ def test_client_push(httpserver: HTTPServer):
def test_client_push_stream(httpserver: HTTPServer): def test_client_push_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/push', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/push',
'insecure': False, method='POST',
'stream': True, json={
}).respond_with_json({}) 'model': 'dummy',
'insecure': False,
'stream': True,
},
).respond_with_json({})
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
response = client.push('dummy', stream=True) response = client.push('dummy', stream=True)
@ -241,11 +304,15 @@ def test_client_push_stream(httpserver: HTTPServer):
def test_client_create_path(httpserver: HTTPServer): def test_client_create_path(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/create',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
@ -260,11 +327,15 @@ def test_client_create_path(httpserver: HTTPServer):
def test_client_create_path_relative(httpserver: HTTPServer): def test_client_create_path_relative(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/create',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
@ -288,11 +359,15 @@ def userhomedir():
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir): def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/create',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
@ -307,11 +382,15 @@ def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
def test_client_create_modelfile(httpserver: HTTPServer): def test_client_create_modelfile(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/create',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
@ -321,11 +400,15 @@ def test_client_create_modelfile(httpserver: HTTPServer):
def test_client_create_from_library(httpserver: HTTPServer): def test_client_create_from_library(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/create', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/create',
'modelfile': 'FROM llama2\n', method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'modelfile': 'FROM llama2\n',
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/')) client = Client(httpserver.url_for('/'))
@ -356,13 +439,17 @@ def test_client_create_blob_exists(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_chat(httpserver: HTTPServer): async def test_async_client_chat(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/chat', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/chat',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], method='POST',
'stream': False, json={
'format': '', 'model': 'dummy',
'options': {}, 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
}).respond_with_json({}) 'stream': False,
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}]) response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
@ -371,13 +458,17 @@ async def test_async_client_chat(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_chat_stream(httpserver: HTTPServer): async def test_async_client_chat_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/chat', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/chat',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], method='POST',
'stream': True, json={
'format': '', 'model': 'dummy',
'options': {}, 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
}).respond_with_json({}) 'stream': True,
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True) response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
@ -386,19 +477,23 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_chat_images(httpserver: HTTPServer): async def test_async_client_chat_images(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/chat', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/chat',
'messages': [ method='POST',
{ json={
'role': 'user', 'model': 'dummy',
'content': 'Why is the sky blue?', 'messages': [
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], {
}, 'role': 'user',
], 'content': 'Why is the sky blue?',
'stream': False, 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'format': '', },
'options': {}, ],
}).respond_with_json({}) 'stream': False,
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
@ -410,18 +505,22 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_generate(httpserver: HTTPServer): async def test_async_client_generate(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/generate', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/generate',
'prompt': 'Why is the sky blue?', method='POST',
'system': '', json={
'template': '', 'model': 'dummy',
'context': [], 'prompt': 'Why is the sky blue?',
'stream': False, 'system': '',
'raw': False, 'template': '',
'images': [], 'context': [],
'format': '', 'stream': False,
'options': {}, 'raw': False,
}).respond_with_json({}) 'images': [],
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
response = await client.generate('dummy', 'Why is the sky blue?') response = await client.generate('dummy', 'Why is the sky blue?')
@ -430,18 +529,22 @@ async def test_async_client_generate(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_generate_stream(httpserver: HTTPServer): async def test_async_client_generate_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/generate', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/generate',
'prompt': 'Why is the sky blue?', method='POST',
'system': '', json={
'template': '', 'model': 'dummy',
'context': [], 'prompt': 'Why is the sky blue?',
'stream': True, 'system': '',
'raw': False, 'template': '',
'images': [], 'context': [],
'format': '', 'stream': True,
'options': {}, 'raw': False,
}).respond_with_json({}) 'images': [],
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
response = await client.generate('dummy', 'Why is the sky blue?', stream=True) response = await client.generate('dummy', 'Why is the sky blue?', stream=True)
@ -450,18 +553,22 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_generate_images(httpserver: HTTPServer): async def test_async_client_generate_images(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/generate', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/generate',
'prompt': 'Why is the sky blue?', method='POST',
'system': '', json={
'template': '', 'model': 'dummy',
'context': [], 'prompt': 'Why is the sky blue?',
'stream': False, 'system': '',
'raw': False, 'template': '',
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], 'context': [],
'format': '', 'stream': False,
'options': {}, 'raw': False,
}).respond_with_json({}) 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
@ -473,11 +580,15 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_pull(httpserver: HTTPServer): async def test_async_client_pull(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/pull', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/pull',
'insecure': False, method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'insecure': False,
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
response = await client.pull('dummy') response = await client.pull('dummy')
@ -486,11 +597,15 @@ async def test_async_client_pull(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_pull_stream(httpserver: HTTPServer): async def test_async_client_pull_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/pull', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/pull',
'insecure': False, method='POST',
'stream': True, json={
}).respond_with_json({}) 'model': 'dummy',
'insecure': False,
'stream': True,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
response = await client.pull('dummy', stream=True) response = await client.pull('dummy', stream=True)
@ -499,11 +614,15 @@ async def test_async_client_pull_stream(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_push(httpserver: HTTPServer): async def test_async_client_push(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/push', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/push',
'insecure': False, method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'insecure': False,
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
response = await client.push('dummy') response = await client.push('dummy')
@ -512,11 +631,15 @@ async def test_async_client_push(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_push_stream(httpserver: HTTPServer): async def test_async_client_push_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/push', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/push',
'insecure': False, method='POST',
'stream': True, json={
}).respond_with_json({}) 'model': 'dummy',
'insecure': False,
'stream': True,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
response = await client.push('dummy', stream=True) response = await client.push('dummy', stream=True)
@ -526,11 +649,15 @@ async def test_async_client_push_stream(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_create_path(httpserver: HTTPServer): async def test_async_client_create_path(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/create',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
@ -546,11 +673,15 @@ async def test_async_client_create_path(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_create_path_relative(httpserver: HTTPServer): async def test_async_client_create_path_relative(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/create',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
@ -566,11 +697,15 @@ async def test_async_client_create_path_relative(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir): async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/create',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
@ -586,11 +721,15 @@ async def test_async_client_create_path_user_home(httpserver: HTTPServer, userho
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_create_modelfile(httpserver: HTTPServer): async def test_async_client_create_modelfile(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200)) httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/create',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n', method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))
@ -601,11 +740,15 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_client_create_from_library(httpserver: HTTPServer): async def test_async_client_create_from_library(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/create', method='POST', json={ httpserver.expect_ordered_request(
'model': 'dummy', '/api/create',
'modelfile': 'FROM llama2\n', method='POST',
'stream': False, json={
}).respond_with_json({}) 'model': 'dummy',
'modelfile': 'FROM llama2\n',
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/')) client = AsyncClient(httpserver.url_for('/'))