diff --git a/ollama/_types.py b/ollama/_types.py index 76a6174..fc4e178 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -401,9 +401,9 @@ class PushRequest(BaseStreamableRequest): class CreateRequest(BaseStreamableRequest): - @model_serializer - def serialize_model(self): - output = {k: v for k, v in self.__dict__.items() if v is not None} + @model_serializer(mode='wrap') + def serialize_model(self, nxt): + output = nxt(self) if 'from_' in output: output['from'] = output.pop('from_') return output diff --git a/tests/test_type_serialization.py b/tests/test_type_serialization.py index 8200ce3..195bac8 100644 --- a/tests/test_type_serialization.py +++ b/tests/test_type_serialization.py @@ -2,7 +2,7 @@ from base64 import b64encode from pathlib import Path import pytest -from ollama._types import Image +from ollama._types import CreateRequest, Image import tempfile @@ -52,3 +52,67 @@ def test_image_serialization_string_path(): with pytest.raises(ValueError): img = Image(value='not an image') img.model_dump() + + + + +def test_create_request_serialization(): + request = CreateRequest( + model="test-model", + from_="base-model", + quantize="q4_0", + files={"file1": "content1"}, + adapters={"adapter1": "content1"}, + template="test template", + license="MIT", + system="test system", + parameters={"param1": "value1"} + ) + + serialized = request.model_dump() + assert serialized["from"] == "base-model" + assert "from_" not in serialized + assert serialized["quantize"] == "q4_0" + assert serialized["files"] == {"file1": "content1"} + assert serialized["adapters"] == {"adapter1": "content1"} + assert serialized["template"] == "test template" + assert serialized["license"] == "MIT" + assert serialized["system"] == "test system" + assert serialized["parameters"] == {"param1": "value1"} + + + +def test_create_request_serialization_exclude_none_true(): + request = CreateRequest( + model="test-model", + from_=None, + quantize=None + ) + serialized = request.model_dump(exclude_none=True) + assert serialized == {"model": "test-model"} + assert "from" not in serialized + assert "from_" not in serialized + assert "quantize" not in serialized + + +def test_create_request_serialization_exclude_none_false(): + request = CreateRequest( + model="test-model", + from_=None, + quantize=None + ) + serialized = request.model_dump(exclude_none=False) + assert "from" in serialized + assert "quantize" in serialized + assert "adapters" in serialized + assert "from_" not in serialized + + +def test_create_request_serialization_license_list(): + request = CreateRequest( + model="test-model", + license=["MIT", "Apache-2.0"] + ) + serialized = request.model_dump() + assert serialized["license"] == ["MIT", "Apache-2.0"] +