diff --git a/ollama/_types.py b/ollama/_types.py index b223d9c..968099d 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -97,6 +97,7 @@ class BaseGenerateRequest(BaseStreamableRequest): class Image(BaseModel): value: Union[FilePath, Base64Str, bytes] + # This overloads the `model_dump` method and returns values depending on the type of the `value` field @model_serializer def serialize_model(self): if isinstance(self.value, Path): @@ -333,7 +334,7 @@ class ModelDetails(SubscriptableBaseModel): class ListResponse(SubscriptableBaseModel): - class Model(BaseModel): + class Model(SubscriptableBaseModel): modified_at: Optional[datetime] = None digest: Optional[str] = None size: Optional[ByteSize] = None @@ -394,7 +395,7 @@ class ShowResponse(SubscriptableBaseModel): class ProcessResponse(SubscriptableBaseModel): - class Model(BaseModel): + class Model(SubscriptableBaseModel): model: Optional[str] = None name: Optional[str] = None digest: Optional[str] = None diff --git a/tests/test_type_serialization.py b/tests/test_type_serialization.py new file mode 100644 index 0000000..f127b03 --- /dev/null +++ b/tests/test_type_serialization.py @@ -0,0 +1,15 @@ +from base64 import b64decode, b64encode + +from ollama._types import Image + + +def test_image_serialization(): + # Test bytes serialization + image_bytes = b'test image bytes' + img = Image(value=image_bytes) + assert img.model_dump() == b64encode(image_bytes).decode() + + # Test base64 string serialization + b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n' + img = Image(value=b64_str) + assert img.model_dump() == b64decode(b64_str).decode()