mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-14 06:07:17 +08:00
Add custom serializer for CreateRequest + tests
This commit is contained in:
parent
bee11029f7
commit
4f9fb88137
@ -401,9 +401,9 @@ class PushRequest(BaseStreamableRequest):
|
|||||||
|
|
||||||
|
|
||||||
class CreateRequest(BaseStreamableRequest):
|
class CreateRequest(BaseStreamableRequest):
|
||||||
@model_serializer
|
@model_serializer(mode='wrap')
|
||||||
def serialize_model(self):
|
def serialize_model(self, nxt):
|
||||||
output = {k: v for k, v in self.__dict__.items() if v is not None}
|
output = nxt(self)
|
||||||
if 'from_' in output:
|
if 'from_' in output:
|
||||||
output['from'] = output.pop('from_')
|
output['from'] = output.pop('from_')
|
||||||
return output
|
return output
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from base64 import b64encode
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from ollama._types import Image
|
from ollama._types import CreateRequest, Image
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
|
|
||||||
@ -52,3 +52,67 @@ def test_image_serialization_string_path():
|
|||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
img = Image(value='not an image')
|
img = Image(value='not an image')
|
||||||
img.model_dump()
|
img.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_request_serialization():
|
||||||
|
request = CreateRequest(
|
||||||
|
model="test-model",
|
||||||
|
from_="base-model",
|
||||||
|
quantize="q4_0",
|
||||||
|
files={"file1": "content1"},
|
||||||
|
adapters={"adapter1": "content1"},
|
||||||
|
template="test template",
|
||||||
|
license="MIT",
|
||||||
|
system="test system",
|
||||||
|
parameters={"param1": "value1"}
|
||||||
|
)
|
||||||
|
|
||||||
|
serialized = request.model_dump()
|
||||||
|
assert serialized["from"] == "base-model"
|
||||||
|
assert "from_" not in serialized
|
||||||
|
assert serialized["quantize"] == "q4_0"
|
||||||
|
assert serialized["files"] == {"file1": "content1"}
|
||||||
|
assert serialized["adapters"] == {"adapter1": "content1"}
|
||||||
|
assert serialized["template"] == "test template"
|
||||||
|
assert serialized["license"] == "MIT"
|
||||||
|
assert serialized["system"] == "test system"
|
||||||
|
assert serialized["parameters"] == {"param1": "value1"}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_request_serialization_exclude_none_true():
|
||||||
|
request = CreateRequest(
|
||||||
|
model="test-model",
|
||||||
|
from_=None,
|
||||||
|
quantize=None
|
||||||
|
)
|
||||||
|
serialized = request.model_dump(exclude_none=True)
|
||||||
|
assert serialized == {"model": "test-model"}
|
||||||
|
assert "from" not in serialized
|
||||||
|
assert "from_" not in serialized
|
||||||
|
assert "quantize" not in serialized
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_request_serialization_exclude_none_false():
|
||||||
|
request = CreateRequest(
|
||||||
|
model="test-model",
|
||||||
|
from_=None,
|
||||||
|
quantize=None
|
||||||
|
)
|
||||||
|
serialized = request.model_dump(exclude_none=False)
|
||||||
|
assert "from" in serialized
|
||||||
|
assert "quantize" in serialized
|
||||||
|
assert "adapters" in serialized
|
||||||
|
assert "from_" not in serialized
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_request_serialization_license_list():
|
||||||
|
request = CreateRequest(
|
||||||
|
model="test-model",
|
||||||
|
license=["MIT", "Apache-2.0"]
|
||||||
|
)
|
||||||
|
serialized = request.model_dump()
|
||||||
|
assert serialized["license"] == ["MIT", "Apache-2.0"]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user