Merge pull request #417 from ollama/pdevine/create-api

add create api changes
This commit is contained in:
Patrick Devine 2025-01-13 17:27:25 -08:00 committed by GitHub
commit 2cad1f5428
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 218 additions and 302 deletions

View File

@ -1,5 +1,4 @@
import os
import io
import json
import platform
import ipaddress
@ -19,6 +18,8 @@ from typing import (
TypeVar,
Union,
overload,
Dict,
List,
)
import sys
@ -62,7 +63,6 @@ from ollama._types import (
ProgressResponse,
PullRequest,
PushRequest,
RequestError,
ResponseError,
ShowRequest,
ShowResponse,
@ -476,10 +476,16 @@ class Client(BaseClient):
def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
*,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
stream: Literal[False] = False,
) -> ProgressResponse: ...
@ -487,20 +493,32 @@ class Client(BaseClient):
def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
*,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
stream: Literal[True] = True,
) -> Iterator[ProgressResponse]: ...
def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
*,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
stream: bool = False,
) -> Union[ProgressResponse, Iterator[ProgressResponse]]:
"""
@ -508,45 +526,27 @@ class Client(BaseClient):
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
if (realpath := _as_path(path)) and realpath.exists():
modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile:
modelfile = self._parse_modelfile(modelfile)
else:
raise RequestError('must provide either path or modelfile')
return self._request(
ProgressResponse,
'POST',
'/api/create',
json=CreateRequest(
model=model,
modelfile=modelfile,
stream=stream,
quantize=quantize,
from_=from_,
files=files,
adapters=adapters,
license=license,
template=template,
system=system,
parameters=parameters,
messages=messages,
).model_dump(exclude_none=True),
stream=stream,
)
def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base
out = io.StringIO()
for line in io.StringIO(modelfile):
command, _, args = line.partition(' ')
if command.upper() not in ['FROM', 'ADAPTER']:
print(line, end='', file=out)
continue
path = Path(args.strip()).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{self._create_blob(path)}\n'
print(command, args, end='', file=out)
return out.getvalue()
def _create_blob(self, path: Union[str, Path]) -> str:
def create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
@ -978,31 +978,49 @@ class AsyncClient(BaseClient):
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
*,
quantize: Optional[str] = None,
stream: Literal[False] = False,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
stream: Literal[True] = True,
) -> ProgressResponse: ...
@overload
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
*,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
stream: Literal[True] = True,
) -> AsyncIterator[ProgressResponse]: ...
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
*,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[Dict[str, str]] = None,
adapters: Optional[Dict[str, str]] = None,
template: Optional[str] = None,
license: Optional[Union[str, List[str]]] = None,
system: Optional[str] = None,
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
stream: bool = False,
) -> Union[ProgressResponse, AsyncIterator[ProgressResponse]]:
"""
@ -1010,12 +1028,6 @@ class AsyncClient(BaseClient):
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
if (realpath := _as_path(path)) and realpath.exists():
modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile:
modelfile = await self._parse_modelfile(modelfile)
else:
raise RequestError('must provide either path or modelfile')
return await self._request(
ProgressResponse,
@ -1023,32 +1035,21 @@ class AsyncClient(BaseClient):
'/api/create',
json=CreateRequest(
model=model,
modelfile=modelfile,
stream=stream,
quantize=quantize,
from_=from_,
files=files,
adapters=adapters,
license=license,
template=template,
system=system,
parameters=parameters,
messages=messages,
).model_dump(exclude_none=True),
stream=stream,
)
async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base
out = io.StringIO()
for line in io.StringIO(modelfile):
command, _, args = line.partition(' ')
if command.upper() not in ['FROM', 'ADAPTER']:
print(line, end='', file=out)
continue
path = Path(args.strip()).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{await self._create_blob(path)}\n'
print(command, args, end='', file=out)
return out.getvalue()
async def _create_blob(self, path: Union[str, Path]) -> str:
async def create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:

View File

@ -2,7 +2,7 @@ import json
from base64 import b64decode, b64encode
from pathlib import Path
from datetime import datetime
from typing import Any, Mapping, Optional, Union, Sequence
from typing import Any, Mapping, Optional, Union, Sequence, Dict, List
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import Annotated, Literal
@ -401,13 +401,25 @@ class PushRequest(BaseStreamableRequest):
class CreateRequest(BaseStreamableRequest):
@model_serializer(mode='wrap')
def serialize_model(self, nxt):
output = nxt(self)
if 'from_' in output:
output['from'] = output.pop('from_')
return output
"""
Request to create a new model.
"""
modelfile: Optional[str] = None
quantize: Optional[str] = None
from_: Optional[str] = None
files: Optional[Dict[str, str]] = None
adapters: Optional[Dict[str, str]] = None
template: Optional[str] = None
license: Optional[Union[str, List[str]]] = None
system: Optional[str] = None
parameters: Optional[Union[Mapping[str, Any], Options]] = None
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None
class ModelDetails(SubscriptableBaseModel):

32
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -6,6 +6,7 @@ version = "0.7.0"
description = "Reusable constraint types to use with typing.Annotated"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"},
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
@ -20,6 +21,7 @@ version = "4.5.2"
description = "High level compatibility layer for multiple asynchronous event loop implementations"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "anyio-4.5.2-py3-none-any.whl", hash = "sha256:c011ee36bc1e8ba40e5a81cb9df91925c218fe9b778554e0b56a21e1b5d4716f"},
{file = "anyio-4.5.2.tar.gz", hash = "sha256:23009af4ed04ce05991845451e11ef02fc7c5ed29179ac9a420e5ad0ac7ddc5b"},
@ -42,6 +44,7 @@ version = "2024.8.30"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.6"
groups = ["main"]
files = [
{file = "certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8"},
{file = "certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9"},
@ -53,6 +56,8 @@ version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
groups = ["dev"]
markers = "sys_platform == \"win32\""
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
@ -64,6 +69,7 @@ version = "7.6.1"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "coverage-7.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b06079abebbc0e89e6163b8e8f0e16270124c154dc6e4a47b413dd538859af16"},
{file = "coverage-7.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b19715bccd7ee27b6b120e7e9dd56037b9c0681dcc1adc9ba9db3d417fa36"},
@ -151,6 +157,8 @@ version = "1.2.2"
description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
groups = ["main", "dev"]
markers = "python_version < \"3.11\""
files = [
{file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
@ -165,6 +173,7 @@ version = "0.14.0"
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
{file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
@ -176,6 +185,7 @@ version = "1.0.6"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "httpcore-1.0.6-py3-none-any.whl", hash = "sha256:27b59625743b85577a8c0e10e55b50b5368a4f2cfe8cc7bcfa9cf00829c2682f"},
{file = "httpcore-1.0.6.tar.gz", hash = "sha256:73f6dbd6eb8c21bbf7ef8efad555481853f5f6acdeaff1edb0694289269ee17f"},
@ -197,6 +207,7 @@ version = "0.27.2"
description = "The next generation HTTP client."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"},
{file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
@ -222,6 +233,7 @@ version = "3.10"
description = "Internationalized Domain Names in Applications (IDNA)"
optional = false
python-versions = ">=3.6"
groups = ["main"]
files = [
{file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"},
{file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"},
@ -236,6 +248,7 @@ version = "2.0.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
@ -247,6 +260,7 @@ version = "2.1.5"
description = "Safely add untrusted strings to HTML/XML markup."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"},
{file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"},
@ -316,6 +330,7 @@ version = "24.1"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"},
{file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"},
@ -327,6 +342,7 @@ version = "1.5.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
{file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
@ -342,6 +358,7 @@ version = "2.9.2"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12"},
{file = "pydantic-2.9.2.tar.gz", hash = "sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f"},
@ -365,6 +382,7 @@ version = "2.23.4"
description = "Core functionality for Pydantic validation and serialization"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "pydantic_core-2.23.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b"},
{file = "pydantic_core-2.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166"},
@ -466,6 +484,7 @@ version = "8.3.4"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"},
{file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"},
@ -488,6 +507,7 @@ version = "0.24.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"},
{file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"},
@ -506,6 +526,7 @@ version = "5.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
@ -524,6 +545,7 @@ version = "1.1.0"
description = "pytest-httpserver is a httpserver for pytest"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "pytest_httpserver-1.1.0-py3-none-any.whl", hash = "sha256:7ef88be8ed3354b6784daa3daa75a422370327c634053cefb124903fa8d73a41"},
{file = "pytest_httpserver-1.1.0.tar.gz", hash = "sha256:6b1cb0199e2ed551b1b94d43f096863bbf6ae5bcd7c75c2c06845e5ce2dc8701"},
@ -538,6 +560,7 @@ version = "0.7.4"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "ruff-0.7.4-py3-none-linux_armv6l.whl", hash = "sha256:a4919925e7684a3f18e18243cd6bea7cfb8e968a6eaa8437971f681b7ec51478"},
{file = "ruff-0.7.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfb365c135b830778dda8c04fb7d4280ed0b984e1aec27f574445231e20d6c63"},
@ -565,6 +588,7 @@ version = "1.3.1"
description = "Sniff out which async library your code is running under"
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
@ -576,6 +600,8 @@ version = "2.0.2"
description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
markers = "python_full_version <= \"3.11.0a6\""
files = [
{file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"},
{file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"},
@ -587,6 +613,7 @@ version = "4.12.2"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
@ -598,6 +625,7 @@ version = "3.0.6"
description = "The comprehensive WSGI web application library."
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"},
{file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"},
@ -610,6 +638,6 @@ MarkupSafe = ">=2.1.1"
watchdog = ["watchdog (>=2.3)"]
[metadata]
lock-version = "2.0"
lock-version = "2.1"
python-versions = "^3.8"
content-hash = "8e93767305535b0a02f0d724edf1249fd928ff1021644eb9dc26dbfa191f6971"

View File

@ -13,6 +13,9 @@ python = "^3.8"
httpx = "^0.27.0"
pydantic = "^2.9.0"
[tool.poetry.requires-plugins]
poetry-plugin-export = ">=1.8"
[tool.poetry.group.dev.dependencies]
pytest = ">=7.4.3,<9.0.0"
pytest-asyncio = ">=0.23.2,<0.25.0"

View File

@ -536,52 +536,6 @@ def test_client_push_stream(httpserver: HTTPServer):
assert part['status'] == next(it)
def test_client_create_path(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile() as blob:
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'
def test_client_create_path_relative(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'
@pytest.fixture
def userhomedir():
with tempfile.TemporaryDirectory() as temp:
@ -591,92 +545,56 @@ def userhomedir():
os.environ['HOME'] = home
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
def test_client_create_with_blob(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'files': {'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'
def test_client_create_modelfile(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = client.create('dummy', modelfile=f'FROM {blob.name}')
with tempfile.NamedTemporaryFile():
response = client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
assert response['status'] == 'success'
def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
def test_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
{{.Prompt}} [/INST]"""
SYSTEM """
Use
multiline
strings.
"""
PARAMETER stop [INST]
PARAMETER stop [/INST]
PARAMETER stop <<SYS>>
PARAMETER stop <</SYS>>''',
'quantize': 'q4_k_m',
'from': 'mymodel',
'adapters': {'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'template': '[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
'license': 'this is my license',
'system': '\nUse\nmultiline\nstrings.\n',
'parameters': {'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
'messages': [{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
with tempfile.NamedTemporaryFile():
response = client.create(
'dummy',
modelfile='\n'.join(
[
f'FROM {blob.name}',
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
'{{.Prompt}} [/INST]"""',
'SYSTEM """',
'Use',
'multiline',
'strings.',
'"""',
'PARAMETER stop [INST]',
'PARAMETER stop [/INST]',
'PARAMETER stop <<SYS>>',
'PARAMETER stop <</SYS>>',
]
),
quantize='q4_k_m',
from_='mymodel',
adapters={'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
template='[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
license='this is my license',
system='\nUse\nmultiline\nstrings.\n',
parameters={'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
messages=[{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
stream=False,
)
assert response['status'] == 'success'
@ -687,14 +605,14 @@ def test_client_create_from_library(httpserver: HTTPServer):
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM llama2',
'from': 'llama2',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
response = client.create('dummy', modelfile='FROM llama2')
response = client.create('dummy', from_='llama2')
assert response['status'] == 'success'
@ -704,7 +622,7 @@ def test_client_create_blob(httpserver: HTTPServer):
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = client._create_blob(blob.name)
response = client.create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
@ -714,7 +632,7 @@ def test_client_create_blob_exists(httpserver: HTTPServer):
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = client._create_blob(blob.name)
response = client.create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
@ -1015,142 +933,57 @@ async def test_async_client_push_stream(httpserver: HTTPServer):
@pytest.mark.asyncio
async def test_async_client_create_path(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
async def test_async_client_create_with_blob(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'files': {'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'stream': False,
},
).respond_with_json({'status': 'success'})
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile() as blob:
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
modelfile.flush()
response = await client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'
@pytest.mark.asyncio
async def test_async_client_create_path_relative(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = await client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'
@pytest.mark.asyncio
async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = await client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'
@pytest.mark.asyncio
async def test_async_client_create_modelfile(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = await client.create('dummy', modelfile=f'FROM {blob.name}')
with tempfile.NamedTemporaryFile():
response = await client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
assert response['status'] == 'success'
@pytest.mark.asyncio
async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
async def test_async_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
{{.Prompt}} [/INST]"""
SYSTEM """
Use
multiline
strings.
"""
PARAMETER stop [INST]
PARAMETER stop [/INST]
PARAMETER stop <<SYS>>
PARAMETER stop <</SYS>>''',
'quantize': 'q4_k_m',
'from': 'mymodel',
'adapters': {'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'template': '[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
'license': 'this is my license',
'system': '\nUse\nmultiline\nstrings.\n',
'parameters': {'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
'messages': [{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
'stream': False,
},
).respond_with_json({'status': 'success'})
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
with tempfile.NamedTemporaryFile():
response = await client.create(
'dummy',
modelfile='\n'.join(
[
f'FROM {blob.name}',
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
'{{.Prompt}} [/INST]"""',
'SYSTEM """',
'Use',
'multiline',
'strings.',
'"""',
'PARAMETER stop [INST]',
'PARAMETER stop [/INST]',
'PARAMETER stop <<SYS>>',
'PARAMETER stop <</SYS>>',
]
),
quantize='q4_k_m',
from_='mymodel',
adapters={'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
template='[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
license='this is my license',
system='\nUse\nmultiline\nstrings.\n',
parameters={'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
messages=[{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
stream=False,
)
assert response['status'] == 'success'
@ -1162,14 +995,14 @@ async def test_async_client_create_from_library(httpserver: HTTPServer):
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM llama2',
'from': 'llama2',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = AsyncClient(httpserver.url_for('/'))
response = await client.create('dummy', modelfile='FROM llama2')
response = await client.create('dummy', from_='llama2')
assert response['status'] == 'success'
@ -1180,7 +1013,7 @@ async def test_async_client_create_blob(httpserver: HTTPServer):
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = await client._create_blob(blob.name)
response = await client.create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
@ -1191,7 +1024,7 @@ async def test_async_client_create_blob_exists(httpserver: HTTPServer):
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = await client._create_blob(blob.name)
response = await client.create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'

View File

@ -2,7 +2,7 @@ from base64 import b64encode
from pathlib import Path
import pytest
from ollama._types import Image
from ollama._types import CreateRequest, Image
import tempfile
@ -52,3 +52,42 @@ def test_image_serialization_string_path():
with pytest.raises(ValueError):
img = Image(value='not an image')
img.model_dump()
def test_create_request_serialization():
request = CreateRequest(model='test-model', from_='base-model', quantize='q4_0', files={'file1': 'content1'}, adapters={'adapter1': 'content1'}, template='test template', license='MIT', system='test system', parameters={'param1': 'value1'})
serialized = request.model_dump()
assert serialized['from'] == 'base-model'
assert 'from_' not in serialized
assert serialized['quantize'] == 'q4_0'
assert serialized['files'] == {'file1': 'content1'}
assert serialized['adapters'] == {'adapter1': 'content1'}
assert serialized['template'] == 'test template'
assert serialized['license'] == 'MIT'
assert serialized['system'] == 'test system'
assert serialized['parameters'] == {'param1': 'value1'}
def test_create_request_serialization_exclude_none_true():
request = CreateRequest(model='test-model', from_=None, quantize=None)
serialized = request.model_dump(exclude_none=True)
assert serialized == {'model': 'test-model'}
assert 'from' not in serialized
assert 'from_' not in serialized
assert 'quantize' not in serialized
def test_create_request_serialization_exclude_none_false():
request = CreateRequest(model='test-model', from_=None, quantize=None)
serialized = request.model_dump(exclude_none=False)
assert 'from' in serialized
assert 'quantize' in serialized
assert 'adapters' in serialized
assert 'from_' not in serialized
def test_create_request_serialization_license_list():
request = CreateRequest(model='test-model', license=['MIT', 'Apache-2.0'])
serialized = request.model_dump()
assert serialized['license'] == ['MIT', 'Apache-2.0']