Add custom serializer for CreateRequest + tests

This commit is contained in:
ParthSareen 2025-01-13 11:02:00 -08:00
parent bee11029f7
commit 4f9fb88137
2 changed files with 68 additions and 4 deletions

View File

@ -401,9 +401,9 @@ class PushRequest(BaseStreamableRequest):
class CreateRequest(BaseStreamableRequest):
@model_serializer
def serialize_model(self):
output = {k: v for k, v in self.__dict__.items() if v is not None}
@model_serializer(mode='wrap')
def serialize_model(self, nxt):
output = nxt(self)
if 'from_' in output:
output['from'] = output.pop('from_')
return output

View File

@ -2,7 +2,7 @@ from base64 import b64encode
from pathlib import Path
import pytest
from ollama._types import Image
from ollama._types import CreateRequest, Image
import tempfile
@ -52,3 +52,67 @@ def test_image_serialization_string_path():
with pytest.raises(ValueError):
img = Image(value='not an image')
img.model_dump()
def test_create_request_serialization():
request = CreateRequest(
model="test-model",
from_="base-model",
quantize="q4_0",
files={"file1": "content1"},
adapters={"adapter1": "content1"},
template="test template",
license="MIT",
system="test system",
parameters={"param1": "value1"}
)
serialized = request.model_dump()
assert serialized["from"] == "base-model"
assert "from_" not in serialized
assert serialized["quantize"] == "q4_0"
assert serialized["files"] == {"file1": "content1"}
assert serialized["adapters"] == {"adapter1": "content1"}
assert serialized["template"] == "test template"
assert serialized["license"] == "MIT"
assert serialized["system"] == "test system"
assert serialized["parameters"] == {"param1": "value1"}
def test_create_request_serialization_exclude_none_true():
request = CreateRequest(
model="test-model",
from_=None,
quantize=None
)
serialized = request.model_dump(exclude_none=True)
assert serialized == {"model": "test-model"}
assert "from" not in serialized
assert "from_" not in serialized
assert "quantize" not in serialized
def test_create_request_serialization_exclude_none_false():
request = CreateRequest(
model="test-model",
from_=None,
quantize=None
)
serialized = request.model_dump(exclude_none=False)
assert "from" in serialized
assert "quantize" in serialized
assert "adapters" in serialized
assert "from_" not in serialized
def test_create_request_serialization_license_list():
request = CreateRequest(
model="test-model",
license=["MIT", "Apache-2.0"]
)
serialized = request.model_dump()
assert serialized["license"] == ["MIT", "Apache-2.0"]