diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 0000000..a3728ca --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,24 @@ +name: publish + +on: + release: + types: + - created + +jobs: + publish: + runs-on: ubuntu-latest + environment: release + permissions: + id-token: write + steps: + - uses: actions/checkout@v4 + - run: pipx install poetry + - uses: actions/setup-python@v5 + with: + cache: poetry + - run: | + poetry version -- ${GITHUB_REF_NAME#v} + poetry build + - uses: pypa/gh-action-pypi-publish@release/v1 + - run: gh release upload $GITHUB_REF_NAME dist/* diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..cc3a2d5 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,32 @@ +name: test + +on: + pull_request: + +jobs: + test: + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - run: pipx install poetry + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: poetry + - run: poetry install --with=dev + - run: poetry run ruff --output-format=github . + - run: poetry run pytest . 
--junitxml=junit/test-results-${{ matrix.python-version }}.xml --cov=ollama --cov-report=xml --cov-report=html + - name: check poetry.lock is up-to-date + run: poetry check --lock + - name: check requirements.txt is up-to-date + run: | + poetry export >requirements.txt + git diff --exit-code requirements.txt + - uses: actions/upload-artifact@v3 + with: + name: pytest-results-${{ matrix.python-version }} + path: junit/test-results-${{ matrix.python-version }}.xml + if: ${{ always() }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..68bc17f --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/README.md b/README.md index e69de29..0183092 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,70 @@ +# Ollama Python Library + +The Ollama Python library provides the easiest way to integrate your Python 3 project with [Ollama](https://github.com/jmorganca/ollama). + +## Getting Started + +Requires Python 3.8 or higher. + +```sh +pip install ollama +``` + +A global default client is provided for convenience and can be used in the same way as the synchronous client. 
+ +```python +import ollama +response = ollama.chat(model='llama2', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}]) +``` + +```python +import ollama +message = {'role': 'user', 'content': 'Why is the sky blue?'} +for part in ollama.chat(model='llama2', messages=[message], stream=True): + print(part['message']['content'], end='', flush=True) +``` + + +### Using the Synchronous Client + +```python +from ollama import Client +message = {'role': 'user', 'content': 'Why is the sky blue?'} +response = Client().chat(model='llama2', messages=[message]) +``` + +Response streaming can be enabled by setting `stream=True`. This modifies the function to return a Python generator where each part is an object in the stream. + +```python +from ollama import Client +message = {'role': 'user', 'content': 'Why is the sky blue?'} +for part in Client().chat(model='llama2', messages=[message], stream=True): + print(part['message']['content'], end='', flush=True) +``` + +### Using the Asynchronous Client + +```python +import asyncio +from ollama import AsyncClient + +async def chat(): + message = {'role': 'user', 'content': 'Why is the sky blue?'} + response = await AsyncClient().chat(model='llama2', messages=[message]) + +asyncio.run(chat()) +``` + +Similar to the synchronous client, setting `stream=True` modifies the function to return a Python asynchronous generator. 
+ +```python +import asyncio +from ollama import AsyncClient + +async def chat(): + message = {'role': 'user', 'content': 'Why is the sky blue?'} + async for part in await AsyncClient().chat(model='llama2', messages=[message], stream=True): + print(part['message']['content'], end='', flush=True) + +asyncio.run(chat()) +``` diff --git a/examples/simple-fill-in-middle/main.py b/examples/simple-fill-in-middle/main.py new file mode 100644 index 0000000..67d7a74 --- /dev/null +++ b/examples/simple-fill-in-middle/main.py @@ -0,0 +1,22 @@ +from ollama import generate + +prefix = '''def remove_non_ascii(s: str) -> str: + """ ''' + +suffix = """ + return result +""" + + +response = generate( + model='codellama:7b-code', + prompt=f'
<PRE> {prefix} <SUF>{suffix} <MID>',
+ options={
+ 'num_predict': 128,
+ 'temperature': 0,
+ 'top_p': 0.9,
+ 'stop': ['<EOT>'],
+ },
+)
+
+print(response['response'])
diff --git a/ollama/__init__.py b/ollama/__init__.py
index 8e7dc22..a66f1d0 100644
--- a/ollama/__init__.py
+++ b/ollama/__init__.py
@@ -1,30 +1,56 @@
-from ollama.client import Client
+from ollama._client import Client, AsyncClient, Message, Options
+
+__all__ = [
+ 'Client',
+ 'AsyncClient',
+ 'Message',
+ 'Options',
+ 'generate',
+ 'chat',
+ 'pull',
+ 'push',
+ 'create',
+ 'delete',
+ 'list',
+ 'copy',
+ 'show',
+]
+
_default_client = Client()
+
def generate(*args, **kwargs):
return _default_client.generate(*args, **kwargs)
+
def chat(*args, **kwargs):
return _default_client.chat(*args, **kwargs)
+
def pull(*args, **kwargs):
return _default_client.pull(*args, **kwargs)
+
def push(*args, **kwargs):
return _default_client.push(*args, **kwargs)
+
def create(*args, **kwargs):
return _default_client.create(*args, **kwargs)
+
def delete(*args, **kwargs):
return _default_client.delete(*args, **kwargs)
+
def list(*args, **kwargs):
return _default_client.list(*args, **kwargs)
+
def copy(*args, **kwargs):
return _default_client.copy(*args, **kwargs)
+
def show(*args, **kwargs):
return _default_client.show(*args, **kwargs)
diff --git a/ollama/_client.py b/ollama/_client.py
new file mode 100644
index 0000000..d0fa30f
--- /dev/null
+++ b/ollama/_client.py
@@ -0,0 +1,458 @@
+import io
+import json
+import httpx
+from os import PathLike
+from pathlib import Path
+from hashlib import sha256
+from base64 import b64encode
+
+from typing import Any, AnyStr, Union, Optional, List, Mapping
+
+import sys
+
+if sys.version_info < (3, 9):
+ from typing import Iterator, AsyncIterator
+else:
+ from collections.abc import Iterator, AsyncIterator
+
+from ollama._types import Message, Options
+
+
+class BaseClient:
+ def __init__(self, client, base_url='http://127.0.0.1:11434') -> None:
+ self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
+
+
+class Client(BaseClient):
+ def __init__(self, base='http://localhost:11434') -> None:
+ super().__init__(httpx.Client, base)
+
+ def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
+ response = self._client.request(method, url, **kwargs)
+ response.raise_for_status()
+ return response
+
+ def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
+ return self._request(method, url, **kwargs).json()
+
+ def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
+ with self._client.stream(method, url, **kwargs) as r:
+ for line in r.iter_lines():
+ part = json.loads(line)
+ if e := part.get('error'):
+ raise Exception(e)
+ yield part
+
+ def generate(
+ self,
+ model: str = '',
+ prompt: str = '',
+ system: str = '',
+ template: str = '',
+ context: Optional[List[int]] = None,
+ stream: bool = False,
+ raw: bool = False,
+ format: str = '',
+ images: Optional[List[AnyStr]] = None,
+ options: Optional[Options] = None,
+ ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+ if not model:
+ raise Exception('must provide a model')
+
+ fn = self._stream if stream else self._request_json
+ return fn(
+ 'POST',
+ '/api/generate',
+ json={
+ 'model': model,
+ 'prompt': prompt,
+ 'system': system,
+ 'template': template,
+ 'context': context or [],
+ 'stream': stream,
+ 'raw': raw,
+ 'images': [_encode_image(image) for image in images or []],
+ 'format': format,
+ 'options': options or {},
+ },
+ )
+
+ def chat(
+ self,
+ model: str = '',
+ messages: Optional[List[Message]] = None,
+ stream: bool = False,
+ format: str = '',
+ options: Optional[Options] = None,
+ ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+ if not model:
+ raise Exception('must provide a model')
+
+ for message in messages or []:
+ if not isinstance(message, dict):
+ raise TypeError('messages must be a list of dicts')
+ if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
+ raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
+ if not message.get('content'):
+ raise Exception('messages must contain content')
+ if images := message.get('images'):
+ message['images'] = [_encode_image(image) for image in images]
+
+ fn = self._stream if stream else self._request_json
+ return fn(
+ 'POST',
+ '/api/chat',
+ json={
+ 'model': model,
+ 'messages': messages,
+ 'stream': stream,
+ 'format': format,
+ 'options': options or {},
+ },
+ )
+
+ def pull(
+ self,
+ model: str,
+ insecure: bool = False,
+ stream: bool = False,
+ ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+ fn = self._stream if stream else self._request_json
+ return fn(
+ 'POST',
+ '/api/pull',
+ json={
+ 'model': model,
+ 'insecure': insecure,
+ 'stream': stream,
+ },
+ )
+
+ def push(
+ self,
+ model: str,
+ insecure: bool = False,
+ stream: bool = False,
+ ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+ fn = self._stream if stream else self._request_json
+ return fn(
+ 'POST',
+ '/api/push',
+ json={
+ 'model': model,
+ 'insecure': insecure,
+ 'stream': stream,
+ },
+ )
+
+ def create(
+ self,
+ model: str,
+ path: Optional[Union[str, PathLike]] = None,
+ modelfile: Optional[str] = None,
+ stream: bool = False,
+ ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+ if (realpath := _as_path(path)) and realpath.exists():
+ modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
+ elif modelfile:
+ modelfile = self._parse_modelfile(modelfile)
+ else:
+ raise Exception('must provide either path or modelfile')
+
+ fn = self._stream if stream else self._request_json
+ return fn(
+ 'POST',
+ '/api/create',
+ json={
+ 'model': model,
+ 'modelfile': modelfile,
+ 'stream': stream,
+ },
+ )
+
+ def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
+ base = Path.cwd() if base is None else base
+
+ out = io.StringIO()
+ for line in io.StringIO(modelfile):
+ command, _, args = line.partition(' ')
+ if command.upper() in ['FROM', 'ADAPTER']:
+ path = Path(args).expanduser()
+ path = path if path.is_absolute() else base / path
+ if path.exists():
+ args = f'@{self._create_blob(path)}'
+
+ print(command, args, file=out)
+ return out.getvalue()
+
+ def _create_blob(self, path: Union[str, Path]) -> str:
+ sha256sum = sha256()
+ with open(path, 'rb') as r:
+ while True:
+ chunk = r.read(32 * 1024)
+ if not chunk:
+ break
+ sha256sum.update(chunk)
+
+ digest = f'sha256:{sha256sum.hexdigest()}'
+
+ try:
+ self._request('HEAD', f'/api/blobs/{digest}')
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code != 404:
+ raise
+
+ with open(path, 'rb') as r:
+ self._request('PUT', f'/api/blobs/{digest}', content=r)
+
+ return digest
+
+ def delete(self, model: str) -> Mapping[str, Any]:
+ response = self._request('DELETE', '/api/delete', json={'model': model})
+ return {'status': 'success' if response.status_code == 200 else 'error'}
+
+ def list(self) -> Mapping[str, Any]:
+ return self._request_json('GET', '/api/tags').get('models', [])
+
+ def copy(self, source: str, target: str) -> Mapping[str, Any]:
+ response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
+ return {'status': 'success' if response.status_code == 200 else 'error'}
+
+ def show(self, model: str) -> Mapping[str, Any]:
+ return self._request_json('GET', '/api/show', json={'model': model})
+
+
+class AsyncClient(BaseClient):
+ def __init__(self, base='http://localhost:11434') -> None:
+ super().__init__(httpx.AsyncClient, base)
+
+ async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
+ response = await self._client.request(method, url, **kwargs)
+ response.raise_for_status()
+ return response
+
+ async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
+ response = await self._request(method, url, **kwargs)
+ return response.json()
+
+ async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]:
+ async def inner():
+ async with self._client.stream(method, url, **kwargs) as r:
+ async for line in r.aiter_lines():
+ part = json.loads(line)
+ if e := part.get('error'):
+ raise Exception(e)
+ yield part
+
+ return inner()
+
+ async def generate(
+ self,
+ model: str = '',
+ prompt: str = '',
+ system: str = '',
+ template: str = '',
+ context: Optional[List[int]] = None,
+ stream: bool = False,
+ raw: bool = False,
+ format: str = '',
+ images: Optional[List[AnyStr]] = None,
+ options: Optional[Options] = None,
+ ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+ if not model:
+ raise Exception('must provide a model')
+
+ fn = self._stream if stream else self._request_json
+ return await fn(
+ 'POST',
+ '/api/generate',
+ json={
+ 'model': model,
+ 'prompt': prompt,
+ 'system': system,
+ 'template': template,
+ 'context': context or [],
+ 'stream': stream,
+ 'raw': raw,
+ 'images': [_encode_image(image) for image in images or []],
+ 'format': format,
+ 'options': options or {},
+ },
+ )
+
+ async def chat(
+ self,
+ model: str = '',
+ messages: Optional[List[Message]] = None,
+ stream: bool = False,
+ format: str = '',
+ options: Optional[Options] = None,
+ ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+ if not model:
+ raise Exception('must provide a model')
+
+ for message in messages or []:
+ if not isinstance(message, dict):
+ raise TypeError('messages must be a list of dicts')
+ if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
+ raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
+ if not message.get('content'):
+ raise Exception('messages must contain content')
+ if images := message.get('images'):
+ message['images'] = [_encode_image(image) for image in images]
+
+ fn = self._stream if stream else self._request_json
+ return await fn(
+ 'POST',
+ '/api/chat',
+ json={
+ 'model': model,
+ 'messages': messages,
+ 'stream': stream,
+ 'format': format,
+ 'options': options or {},
+ },
+ )
+
+ async def pull(
+ self,
+ model: str,
+ insecure: bool = False,
+ stream: bool = False,
+ ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+ fn = self._stream if stream else self._request_json
+ return await fn(
+ 'POST',
+ '/api/pull',
+ json={
+ 'model': model,
+ 'insecure': insecure,
+ 'stream': stream,
+ },
+ )
+
+ async def push(
+ self,
+ model: str,
+ insecure: bool = False,
+ stream: bool = False,
+ ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+ fn = self._stream if stream else self._request_json
+ return await fn(
+ 'POST',
+ '/api/push',
+ json={
+ 'model': model,
+ 'insecure': insecure,
+ 'stream': stream,
+ },
+ )
+
+ async def create(
+ self,
+ model: str,
+ path: Optional[Union[str, PathLike]] = None,
+ modelfile: Optional[str] = None,
+ stream: bool = False,
+ ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+ if (realpath := _as_path(path)) and realpath.exists():
+ modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
+ elif modelfile:
+ modelfile = await self._parse_modelfile(modelfile)
+ else:
+ raise Exception('must provide either path or modelfile')
+
+ fn = self._stream if stream else self._request_json
+ return await fn(
+ 'POST',
+ '/api/create',
+ json={
+ 'model': model,
+ 'modelfile': modelfile,
+ 'stream': stream,
+ },
+ )
+
+ async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
+ base = Path.cwd() if base is None else base
+
+ out = io.StringIO()
+ for line in io.StringIO(modelfile):
+ command, _, args = line.partition(' ')
+ if command.upper() in ['FROM', 'ADAPTER']:
+ path = Path(args).expanduser()
+ path = path if path.is_absolute() else base / path
+ if path.exists():
+ args = f'@{await self._create_blob(path)}'
+
+ print(command, args, file=out)
+ return out.getvalue()
+
+ async def _create_blob(self, path: Union[str, Path]) -> str:
+ sha256sum = sha256()
+ with open(path, 'rb') as r:
+ while True:
+ chunk = r.read(32 * 1024)
+ if not chunk:
+ break
+ sha256sum.update(chunk)
+
+ digest = f'sha256:{sha256sum.hexdigest()}'
+
+ try:
+ await self._request('HEAD', f'/api/blobs/{digest}')
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code != 404:
+ raise
+
+ async def upload_bytes():
+ with open(path, 'rb') as r:
+ while True:
+ chunk = r.read(32 * 1024)
+ if not chunk:
+ break
+ yield chunk
+
+ await self._request('PUT', f'/api/blobs/{digest}', content=upload_bytes())
+
+ return digest
+
+ async def delete(self, model: str) -> Mapping[str, Any]:
+ response = await self._request('DELETE', '/api/delete', json={'model': model})
+ return {'status': 'success' if response.status_code == 200 else 'error'}
+
+ async def list(self) -> Mapping[str, Any]:
+ response = await self._request_json('GET', '/api/tags')
+ return response.get('models', [])
+
+ async def copy(self, source: str, target: str) -> Mapping[str, Any]:
+ response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target})
+ return {'status': 'success' if response.status_code == 200 else 'error'}
+
+ async def show(self, model: str) -> Mapping[str, Any]:
+ return await self._request_json('GET', '/api/show', json={'model': model})
+
+
+def _encode_image(image) -> str:
+ if p := _as_path(image):
+ b64 = b64encode(p.read_bytes())
+ elif b := _as_bytesio(image):
+ b64 = b64encode(b.read())
+ else:
+ raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
+
+ return b64.decode('utf-8')
+
+
+def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
+ if isinstance(s, str) or isinstance(s, Path):
+ return Path(s)
+ return None
+
+
+def _as_bytesio(s: Any) -> Union[io.BytesIO, None]:
+ if isinstance(s, io.BytesIO):
+ return s
+ elif isinstance(s, bytes):
+ return io.BytesIO(s)
+ return None
diff --git a/ollama/_types.py b/ollama/_types.py
new file mode 100644
index 0000000..d263269
--- /dev/null
+++ b/ollama/_types.py
@@ -0,0 +1,53 @@
+from typing import Any, TypedDict, List
+
+import sys
+
+if sys.version_info < (3, 11):
+ from typing_extensions import NotRequired
+else:
+ from typing import NotRequired
+
+
+class Message(TypedDict):
+ role: str
+ content: str
+ images: NotRequired[List[Any]]
+
+
+class Options(TypedDict, total=False):
+ # load time options
+ numa: bool
+ num_ctx: int
+ num_batch: int
+ num_gqa: int
+ num_gpu: int
+ main_gpu: int
+ low_vram: bool
+ f16_kv: bool
+ logits_all: bool
+ vocab_only: bool
+ use_mmap: bool
+ use_mlock: bool
+ embedding_only: bool
+ rope_frequency_base: float
+ rope_frequency_scale: float
+ num_thread: int
+
+ # runtime options
+ num_keep: int
+ seed: int
+ num_predict: int
+ top_k: int
+ top_p: float
+ tfs_z: float
+ typical_p: float
+ repeat_last_n: int
+ temperature: float
+ repeat_penalty: float
+ presence_penalty: float
+ frequency_penalty: float
+ mirostat: int
+ mirostat_tau: float
+ mirostat_eta: float
+ penalize_newline: bool
+ stop: List[str]
diff --git a/ollama/client.py b/ollama/client.py
deleted file mode 100644
index 578757f..0000000
--- a/ollama/client.py
+++ /dev/null
@@ -1,182 +0,0 @@
-import io
-import json
-import httpx
-from pathlib import Path
-from hashlib import sha256
-from base64 import b64encode
-
-
-class BaseClient:
-
- def __init__(self, client, base_url='http://127.0.0.1:11434'):
- self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
-
-
-class Client(BaseClient):
-
- def __init__(self, base='http://localhost:11434'):
- super().__init__(httpx.Client, base)
-
- def _request(self, method, url, **kwargs):
- response = self._client.request(method, url, **kwargs)
- response.raise_for_status()
- return response
-
- def _request_json(self, method, url, **kwargs):
- return self._request(method, url, **kwargs).json()
-
- def stream(self, method, url, **kwargs):
- with self._client.stream(method, url, **kwargs) as r:
- for line in r.iter_lines():
- part = json.loads(line)
- if e := part.get('error'):
- raise Exception(e)
- yield part
-
- def generate(self, model, prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
- fn = self.stream if stream else self._request_json
- return fn('POST', '/api/generate', json={
- 'model': model,
- 'prompt': prompt,
- 'system': system,
- 'template': template,
- 'context': context or [],
- 'stream': stream,
- 'raw': raw,
- 'images': [_encode_image(image) for image in images or []],
- 'format': format,
- 'options': options or {},
- })
-
- def chat(self, model, messages=None, stream=False, format='', options=None):
- for message in messages or []:
- if not isinstance(message, dict):
- raise TypeError('messages must be a list of strings')
- if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
- raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
- if not message.get('content'):
- raise Exception('messages must contain content')
- if images := message.get('images'):
- message['images'] = [_encode_image(image) for image in images]
-
- fn = self.stream if stream else self._request_json
- return fn('POST', '/api/chat', json={
- 'model': model,
- 'messages': messages,
- 'stream': stream,
- 'format': format,
- 'options': options or {},
- })
-
- def pull(self, model, insecure=False, stream=False):
- fn = self.stream if stream else self._request_json
- return fn('POST', '/api/pull', json={
- 'model': model,
- 'insecure': insecure,
- 'stream': stream,
- })
-
- def push(self, model, insecure=False, stream=False):
- fn = self.stream if stream else self._request_json
- return fn('POST', '/api/push', json={
- 'model': model,
- 'insecure': insecure,
- 'stream': stream,
- })
-
- def create(self, model, path=None, modelfile=None, stream=False):
- if (path := _as_path(path)) and path.exists():
- modelfile = _parse_modelfile(path.read_text(), self.create_blob, base=path.parent)
- elif modelfile:
- modelfile = _parse_modelfile(modelfile, self.create_blob)
- else:
- raise Exception('must provide either path or modelfile')
-
- fn = self.stream if stream else self._request_json
- return fn('POST', '/api/create', json={
- 'model': model,
- 'modelfile': modelfile,
- 'stream': stream,
- })
-
-
- def create_blob(self, path):
- sha256sum = sha256()
- with open(path, 'rb') as r:
- while True:
- chunk = r.read(32*1024)
- if not chunk:
- break
- sha256sum.update(chunk)
-
- digest = f'sha256:{sha256sum.hexdigest()}'
-
- try:
- self._request('HEAD', f'/api/blobs/{digest}')
- except httpx.HTTPError:
- with open(path, 'rb') as r:
- self._request('PUT', f'/api/blobs/{digest}', content=r)
-
- return digest
-
- def delete(self, model):
- response = self._request_json('DELETE', '/api/delete', json={'model': model})
- return {'status': 'success' if response.status_code == 200 else 'error'}
-
- def list(self):
- return self._request_json('GET', '/api/tags').get('models', [])
-
- def copy(self, source, target):
- response = self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
- return {'status': 'success' if response.status_code == 200 else 'error'}
-
- def show(self, model):
- return self._request_json('GET', '/api/show', json={'model': model}).json()
-
-
-def _encode_image(image):
- '''
- _encode_images takes a list of images and returns a generator of base64 encoded images.
- if the image is a bytes object, it is assumed to be the raw bytes of an image.
- if the image is a string, it is assumed to be a path to a file.
- if the image is a Path object, it is assumed to be a path to a file.
- if the image is a file-like object, it is assumed to be a container to the raw bytes of an image.
- '''
-
- if p := _as_path(image):
- b64 = b64encode(p.read_bytes())
- elif b := _as_bytesio(image):
- b64 = b64encode(b.read())
- else:
- raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
-
- return b64.decode('utf-8')
-
-
-def _parse_modelfile(modelfile, cb, base=None):
- base = Path.cwd() if base is None else base
-
- out = io.StringIO()
- for line in io.StringIO(modelfile):
- command, _, args = line.partition(' ')
- if command.upper() in ['FROM', 'ADAPTER']:
- path = Path(args).expanduser()
- path = path if path.is_absolute() else base / path
- if path.exists():
- args = f'@{cb(path)}'
-
- print(command, args, file=out)
- return out.getvalue()
-
-
-def _as_path(s):
- if isinstance(s, str) or isinstance(s, Path):
- return Path(s)
- return None
-
-def _as_bytesio(s):
- if isinstance(s, io.BytesIO):
- return s
- elif isinstance(s, bytes):
- return io.BytesIO(s)
- return None
diff --git a/ollama/client_test.py b/ollama/client_test.py
deleted file mode 100644
index cb563dc..0000000
--- a/ollama/client_test.py
+++ /dev/null
@@ -1,292 +0,0 @@
-import pytest
-import os
-import io
-import types
-import tempfile
-from pathlib import Path
-from ollama.client import Client
-from pytest_httpserver import HTTPServer, URIPattern
-from werkzeug.wrappers import Response
-from PIL import Image
-
-
-class PrefixPattern(URIPattern):
- def __init__(self, prefix: str):
- self.prefix = prefix
-
- def match(self, uri):
- return uri.startswith(self.prefix)
-
-
-def test_client_chat(httpserver: HTTPServer):
- httpserver.expect_ordered_request('/api/chat', method='POST', json={
- 'model': 'dummy',
- 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
- 'stream': False,
- 'format': '',
- 'options': {},
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
- response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
- assert isinstance(response, dict)
-
-
-def test_client_chat_stream(httpserver: HTTPServer):
- httpserver.expect_ordered_request('/api/chat', method='POST', json={
- 'model': 'dummy',
- 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
- 'stream': True,
- 'format': '',
- 'options': {},
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
- response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
- assert isinstance(response, types.GeneratorType)
-
-
-def test_client_chat_images(httpserver: HTTPServer):
- httpserver.expect_ordered_request('/api/chat', method='POST', json={
- 'model': 'dummy',
- 'messages': [
- {
- 'role': 'user',
- 'content': 'Why is the sky blue?',
- 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
- },
- ],
- 'stream': False,
- 'format': '',
- 'options': {},
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
-
- with io.BytesIO() as b:
- Image.new('RGB', (1, 1)).save(b, 'PNG')
- response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
- assert isinstance(response, dict)
-
-
-def test_client_generate(httpserver: HTTPServer):
- httpserver.expect_ordered_request('/api/generate', method='POST', json={
- 'model': 'dummy',
- 'prompt': 'Why is the sky blue?',
- 'system': '',
- 'template': '',
- 'context': [],
- 'stream': False,
- 'raw': False,
- 'images': [],
- 'format': '',
- 'options': {},
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
- response = client.generate('dummy', 'Why is the sky blue?')
- assert isinstance(response, dict)
-
-
-def test_client_generate_stream(httpserver: HTTPServer):
- httpserver.expect_ordered_request('/api/generate', method='POST', json={
- 'model': 'dummy',
- 'prompt': 'Why is the sky blue?',
- 'system': '',
- 'template': '',
- 'context': [],
- 'stream': True,
- 'raw': False,
- 'images': [],
- 'format': '',
- 'options': {},
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
- response = client.generate('dummy', 'Why is the sky blue?', stream=True)
- assert isinstance(response, types.GeneratorType)
-
-
-def test_client_generate_images(httpserver: HTTPServer):
- httpserver.expect_ordered_request('/api/generate', method='POST', json={
- 'model': 'dummy',
- 'prompt': 'Why is the sky blue?',
- 'system': '',
- 'template': '',
- 'context': [],
- 'stream': False,
- 'raw': False,
- 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
- 'format': '',
- 'options': {},
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as temp:
- Image.new('RGB', (1, 1)).save(temp, 'PNG')
- response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
- assert isinstance(response, dict)
-
-
-def test_client_pull(httpserver: HTTPServer):
- httpserver.expect_ordered_request('/api/pull', method='POST', json={
- 'model': 'dummy',
- 'insecure': False,
- 'stream': False,
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
- response = client.pull('dummy')
- assert isinstance(response, dict)
-
-
-def test_client_pull_stream(httpserver: HTTPServer):
- httpserver.expect_ordered_request('/api/pull', method='POST', json={
- 'model': 'dummy',
- 'insecure': False,
- 'stream': True,
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
- response = client.pull('dummy', stream=True)
- assert isinstance(response, types.GeneratorType)
-
-
-def test_client_push(httpserver: HTTPServer):
- httpserver.expect_ordered_request('/api/push', method='POST', json={
- 'model': 'dummy',
- 'insecure': False,
- 'stream': False,
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
- response = client.push('dummy')
- assert isinstance(response, dict)
-
-
-def test_client_push_stream(httpserver: HTTPServer):
- httpserver.expect_ordered_request('/api/push', method='POST', json={
- 'model': 'dummy',
- 'insecure': False,
- 'stream': True,
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
- response = client.push('dummy', stream=True)
- assert isinstance(response, types.GeneratorType)
-
-
-def test_client_create_path(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
- httpserver.expect_ordered_request('/api/create', method='POST', json={
- 'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
- 'stream': False,
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as modelfile:
- with tempfile.NamedTemporaryFile() as blob:
- modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
- modelfile.flush()
-
- response = client.create('dummy', path=modelfile.name)
- assert isinstance(response, dict)
-
-
-def test_client_create_path_relative(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
- httpserver.expect_ordered_request('/api/create', method='POST', json={
- 'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
- 'stream': False,
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as modelfile:
- with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
- modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
- modelfile.flush()
-
- response = client.create('dummy', path=modelfile.name)
- assert isinstance(response, dict)
-
-
-@pytest.fixture
-def userhomedir():
- with tempfile.TemporaryDirectory() as temp:
- home = os.getenv('HOME', '')
- os.environ['HOME'] = temp
- yield Path(temp)
- os.environ['HOME'] = home
-
-
-def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
- httpserver.expect_ordered_request('/api/create', method='POST', json={
- 'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
- 'stream': False,
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as modelfile:
- with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
- modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
- modelfile.flush()
-
- response = client.create('dummy', path=modelfile.name)
- assert isinstance(response, dict)
-
-
-def test_client_create_modelfile(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
- httpserver.expect_ordered_request('/api/create', method='POST', json={
- 'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
- 'stream': False,
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as blob:
- response = client.create('dummy', modelfile=f'FROM {blob.name}')
- assert isinstance(response, dict)
-
-
-def test_client_create_from_library(httpserver: HTTPServer):
- httpserver.expect_ordered_request('/api/create', method='POST', json={
- 'model': 'dummy',
- 'modelfile': 'FROM llama2\n',
- 'stream': False,
- }).respond_with_json({})
-
- client = Client(httpserver.url_for('/'))
-
- response = client.create('dummy', modelfile='FROM llama2')
- assert isinstance(response, dict)
-
-
-def test_client_create_blob(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
-
- client = Client(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as blob:
- response = client.create_blob(blob.name)
- assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
-
-
-def test_client_create_blob_exists(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
-
- client = Client(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as blob:
- response = client.create_blob(blob.name)
- assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
diff --git a/poetry.lock b/poetry.lock
index a61b83c..3db7b00 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -387,6 +387,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+[[package]]
+name = "pytest-asyncio"
+version = "0.23.2"
+description = "Pytest support for asyncio"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"},
+ {file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"},
+]
+
+[package.dependencies]
+pytest = ">=7.0.0"
+
+[package.extras]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
+testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
+
[[package]]
name = "pytest-cov"
version = "4.1.0"
@@ -498,4 +516,4 @@ watchdog = ["watchdog (>=2.3)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.8"
-content-hash = "b9f64e1a5795a417d2dbff7286360f8d3f8f10fdfa9580411940d144c2561e92"
+content-hash = "9416a897c95d3c80cf1bfd3cc61cd19f0143c9bd0bc7c219fcb31ee27c497c9d"
diff --git a/pyproject.toml b/pyproject.toml
index 5f3858a..76b5e15 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,9 +1,12 @@
[tool.poetry]
name = "ollama"
-version = "0.1.0"
+version = "0.0.0"
description = "The official Python client for Ollama."
authors = ["Ollama "]
+license = "MIT"
readme = "README.md"
+homepage = "https://ollama.ai"
+repository = "https://github.com/jmorganca/ollama-python"
[tool.poetry.dependencies]
python = "^3.8"
@@ -11,12 +14,18 @@ httpx = "^0.25.2"
[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
+pytest-asyncio = "^0.23.2"
pytest-cov = "^4.1.0"
pytest-httpserver = "^1.0.8"
pillow = "^10.1.0"
ruff = "^0.1.8"
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
+
[tool.ruff]
+line-length = 999
indent-width = 2
[tool.ruff.format]
@@ -26,7 +35,3 @@ indent-style = "space"
[tool.ruff.lint]
select = ["E", "F", "B"]
ignore = ["E501"]
-
-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
diff --git a/tests/test_client.py b/tests/test_client.py
new file mode 100644
index 0000000..fe151dc
--- /dev/null
+++ b/tests/test_client.py
@@ -0,0 +1,779 @@
+import os
+import io
+import json
+import types
+import pytest
+import tempfile
+from pathlib import Path
+from pytest_httpserver import HTTPServer, URIPattern
+from werkzeug.wrappers import Request, Response
+from PIL import Image
+
+from ollama._client import Client, AsyncClient
+
+
+class PrefixPattern(URIPattern):
+ def __init__(self, prefix: str):
+ self.prefix = prefix
+
+ def match(self, uri):
+ return uri.startswith(self.prefix)
+
+
+def test_client_chat(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/chat',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
+ 'stream': False,
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_json(
+ {
+ 'model': 'dummy',
+ 'message': {
+ 'role': 'assistant',
+ 'content': "I don't know.",
+ },
+ }
+ )
+
+ client = Client(httpserver.url_for('/'))
+ response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
+ assert response['model'] == 'dummy'
+ assert response['message']['role'] == 'assistant'
+ assert response['message']['content'] == "I don't know."
+
+
+def test_client_chat_stream(httpserver: HTTPServer):
+ def stream_handler(_: Request):
+ def generate():
+ for message in ['I ', "don't ", 'know.']:
+ yield (
+ json.dumps(
+ {
+ 'model': 'dummy',
+ 'message': {
+ 'role': 'assistant',
+ 'content': message,
+ },
+ }
+ )
+ + '\n'
+ )
+
+ return Response(generate())
+
+ httpserver.expect_ordered_request(
+ '/api/chat',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
+ 'stream': True,
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_handler(stream_handler)
+
+ client = Client(httpserver.url_for('/'))
+ response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
+ for part in response:
+ assert part['message']['role'] in 'assistant'
+ assert part['message']['content'] in ['I ', "don't ", 'know.']
+
+
+def test_client_chat_images(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/chat',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'messages': [
+ {
+ 'role': 'user',
+ 'content': 'Why is the sky blue?',
+ 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
+ },
+ ],
+ 'stream': False,
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_json(
+ {
+ 'model': 'dummy',
+ 'message': {
+ 'role': 'assistant',
+ 'content': "I don't know.",
+ },
+ }
+ )
+
+ client = Client(httpserver.url_for('/'))
+
+ with io.BytesIO() as b:
+ Image.new('RGB', (1, 1)).save(b, 'PNG')
+ response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
+ assert response['model'] == 'dummy'
+ assert response['message']['role'] == 'assistant'
+ assert response['message']['content'] == "I don't know."
+
+
+def test_client_generate(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/generate',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'prompt': 'Why is the sky blue?',
+ 'system': '',
+ 'template': '',
+ 'context': [],
+ 'stream': False,
+ 'raw': False,
+ 'images': [],
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_json(
+ {
+ 'model': 'dummy',
+ 'response': 'Because it is.',
+ }
+ )
+
+ client = Client(httpserver.url_for('/'))
+ response = client.generate('dummy', 'Why is the sky blue?')
+ assert response['model'] == 'dummy'
+ assert response['response'] == 'Because it is.'
+
+
+def test_client_generate_stream(httpserver: HTTPServer):
+ def stream_handler(_: Request):
+ def generate():
+ for message in ['Because ', 'it ', 'is.']:
+ yield (
+ json.dumps(
+ {
+ 'model': 'dummy',
+ 'response': message,
+ }
+ )
+ + '\n'
+ )
+
+ return Response(generate())
+
+ httpserver.expect_ordered_request(
+ '/api/generate',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'prompt': 'Why is the sky blue?',
+ 'system': '',
+ 'template': '',
+ 'context': [],
+ 'stream': True,
+ 'raw': False,
+ 'images': [],
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_handler(stream_handler)
+
+ client = Client(httpserver.url_for('/'))
+ response = client.generate('dummy', 'Why is the sky blue?', stream=True)
+ for part in response:
+ assert part['model'] == 'dummy'
+ assert part['response'] in ['Because ', 'it ', 'is.']
+
+
+def test_client_generate_images(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/generate',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'prompt': 'Why is the sky blue?',
+ 'system': '',
+ 'template': '',
+ 'context': [],
+ 'stream': False,
+ 'raw': False,
+ 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_json(
+ {
+ 'model': 'dummy',
+ 'response': 'Because it is.',
+ }
+ )
+
+ client = Client(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as temp:
+ Image.new('RGB', (1, 1)).save(temp, 'PNG')
+ response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
+ assert response['model'] == 'dummy'
+ assert response['response'] == 'Because it is.'
+
+
+def test_client_pull(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/pull',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'insecure': False,
+ 'stream': False,
+ },
+ ).respond_with_json(
+ {
+ 'status': 'success',
+ }
+ )
+
+ client = Client(httpserver.url_for('/'))
+ response = client.pull('dummy')
+ assert response['status'] == 'success'
+
+
+def test_client_pull_stream(httpserver: HTTPServer):
+ def stream_handler(_: Request):
+ def generate():
+ yield json.dumps({'status': 'pulling manifest'}) + '\n'
+ yield json.dumps({'status': 'verifying sha256 digest'}) + '\n'
+ yield json.dumps({'status': 'writing manifest'}) + '\n'
+ yield json.dumps({'status': 'removing any unused layers'}) + '\n'
+ yield json.dumps({'status': 'success'}) + '\n'
+
+ return Response(generate())
+
+ httpserver.expect_ordered_request(
+ '/api/pull',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'insecure': False,
+ 'stream': True,
+ },
+ ).respond_with_json({})
+
+ client = Client(httpserver.url_for('/'))
+ response = client.pull('dummy', stream=True)
+ assert isinstance(response, types.GeneratorType)
+
+
+def test_client_push(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/push',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'insecure': False,
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = Client(httpserver.url_for('/'))
+ response = client.push('dummy')
+ assert isinstance(response, dict)
+
+
+def test_client_push_stream(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/push',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'insecure': False,
+ 'stream': True,
+ },
+ ).respond_with_json({})
+
+ client = Client(httpserver.url_for('/'))
+ response = client.push('dummy', stream=True)
+ assert isinstance(response, types.GeneratorType)
+
+
+def test_client_create_path(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
+ httpserver.expect_ordered_request(
+ '/api/create',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = Client(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as modelfile:
+ with tempfile.NamedTemporaryFile() as blob:
+ modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
+ modelfile.flush()
+
+ response = client.create('dummy', path=modelfile.name)
+ assert isinstance(response, dict)
+
+
+def test_client_create_path_relative(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
+ httpserver.expect_ordered_request(
+ '/api/create',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = Client(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as modelfile:
+ with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
+ modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
+ modelfile.flush()
+
+ response = client.create('dummy', path=modelfile.name)
+ assert isinstance(response, dict)
+
+
+@pytest.fixture
+def userhomedir():
+ with tempfile.TemporaryDirectory() as temp:
+ home = os.getenv('HOME', '')
+ os.environ['HOME'] = temp
+ yield Path(temp)
+ os.environ['HOME'] = home
+
+
+def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
+ httpserver.expect_ordered_request(
+ '/api/create',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = Client(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as modelfile:
+ with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
+ modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
+ modelfile.flush()
+
+ response = client.create('dummy', path=modelfile.name)
+ assert isinstance(response, dict)
+
+
+def test_client_create_modelfile(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
+ httpserver.expect_ordered_request(
+ '/api/create',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = Client(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as blob:
+ response = client.create('dummy', modelfile=f'FROM {blob.name}')
+ assert isinstance(response, dict)
+
+
+def test_client_create_from_library(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/create',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'modelfile': 'FROM llama2\n',
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = Client(httpserver.url_for('/'))
+
+ response = client.create('dummy', modelfile='FROM llama2')
+ assert isinstance(response, dict)
+
+
+def test_client_create_blob(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
+
+ client = Client(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as blob:
+ response = client._create_blob(blob.name)
+ assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
+
+
+def test_client_create_blob_exists(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
+
+ client = Client(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as blob:
+ response = client._create_blob(blob.name)
+ assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
+
+
+@pytest.mark.asyncio
+async def test_async_client_chat(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/chat',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
+ 'stream': False,
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+ response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
+ assert isinstance(response, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_client_chat_stream(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/chat',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
+ 'stream': True,
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+ response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
+ assert isinstance(response, types.AsyncGeneratorType)
+
+
+@pytest.mark.asyncio
+async def test_async_client_chat_images(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/chat',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'messages': [
+ {
+ 'role': 'user',
+ 'content': 'Why is the sky blue?',
+ 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
+ },
+ ],
+ 'stream': False,
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+
+ with io.BytesIO() as b:
+ Image.new('RGB', (1, 1)).save(b, 'PNG')
+ response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
+ assert isinstance(response, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_client_generate(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/generate',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'prompt': 'Why is the sky blue?',
+ 'system': '',
+ 'template': '',
+ 'context': [],
+ 'stream': False,
+ 'raw': False,
+ 'images': [],
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+ response = await client.generate('dummy', 'Why is the sky blue?')
+ assert isinstance(response, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_client_generate_stream(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/generate',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'prompt': 'Why is the sky blue?',
+ 'system': '',
+ 'template': '',
+ 'context': [],
+ 'stream': True,
+ 'raw': False,
+ 'images': [],
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+ response = await client.generate('dummy', 'Why is the sky blue?', stream=True)
+ assert isinstance(response, types.AsyncGeneratorType)
+
+
+@pytest.mark.asyncio
+async def test_async_client_generate_images(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/generate',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'prompt': 'Why is the sky blue?',
+ 'system': '',
+ 'template': '',
+ 'context': [],
+ 'stream': False,
+ 'raw': False,
+ 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
+ 'format': '',
+ 'options': {},
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as temp:
+ Image.new('RGB', (1, 1)).save(temp, 'PNG')
+ response = await client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
+ assert isinstance(response, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_client_pull(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/pull',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'insecure': False,
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+ response = await client.pull('dummy')
+ assert isinstance(response, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_client_pull_stream(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/pull',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'insecure': False,
+ 'stream': True,
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+ response = await client.pull('dummy', stream=True)
+ assert isinstance(response, types.AsyncGeneratorType)
+
+
+@pytest.mark.asyncio
+async def test_async_client_push(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/push',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'insecure': False,
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+ response = await client.push('dummy')
+ assert isinstance(response, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_client_push_stream(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/push',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'insecure': False,
+ 'stream': True,
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+ response = await client.push('dummy', stream=True)
+ assert isinstance(response, types.AsyncGeneratorType)
+
+
+@pytest.mark.asyncio
+async def test_async_client_create_path(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
+ httpserver.expect_ordered_request(
+ '/api/create',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as modelfile:
+ with tempfile.NamedTemporaryFile() as blob:
+ modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
+ modelfile.flush()
+
+ response = await client.create('dummy', path=modelfile.name)
+ assert isinstance(response, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_client_create_path_relative(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
+ httpserver.expect_ordered_request(
+ '/api/create',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as modelfile:
+ with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
+ modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
+ modelfile.flush()
+
+ response = await client.create('dummy', path=modelfile.name)
+ assert isinstance(response, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
+ httpserver.expect_ordered_request(
+ '/api/create',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as modelfile:
+ with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
+ modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
+ modelfile.flush()
+
+ response = await client.create('dummy', path=modelfile.name)
+ assert isinstance(response, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_client_create_modelfile(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
+ httpserver.expect_ordered_request(
+ '/api/create',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as blob:
+ response = await client.create('dummy', modelfile=f'FROM {blob.name}')
+ assert isinstance(response, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_client_create_from_library(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(
+ '/api/create',
+ method='POST',
+ json={
+ 'model': 'dummy',
+ 'modelfile': 'FROM llama2\n',
+ 'stream': False,
+ },
+ ).respond_with_json({})
+
+ client = AsyncClient(httpserver.url_for('/'))
+
+ response = await client.create('dummy', modelfile='FROM llama2')
+ assert isinstance(response, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_client_create_blob(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
+
+ client = AsyncClient(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as blob:
+ response = await client._create_blob(blob.name)
+ assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
+
+
+@pytest.mark.asyncio
+async def test_async_client_create_blob_exists(httpserver: HTTPServer):
+ httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
+
+ client = AsyncClient(httpserver.url_for('/'))
+
+ with tempfile.NamedTemporaryFile() as blob:
+ response = await client._create_blob(blob.name)
+ assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'