From f25834217be31fb730c9c8d64a7a2d638d489bf5 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Wed, 6 Nov 2024 14:04:56 -0800 Subject: [PATCH] Pydantic Fixes and Tests (#311) * Added SubscriptableBaseModel to the Model classes and added Image codec test --------- Co-authored-by: Parth Sareen --- ollama/_types.py | 5 +++-- tests/test_type_serialization.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 tests/test_type_serialization.py diff --git a/ollama/_types.py b/ollama/_types.py index b223d9c..968099d 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -97,6 +97,7 @@ class BaseGenerateRequest(BaseStreamableRequest): class Image(BaseModel): value: Union[FilePath, Base64Str, bytes] + # This overloads the `model_dump` method and returns values depending on the type of the `value` field @model_serializer def serialize_model(self): if isinstance(self.value, Path): @@ -333,7 +334,7 @@ class ModelDetails(SubscriptableBaseModel): class ListResponse(SubscriptableBaseModel): - class Model(BaseModel): + class Model(SubscriptableBaseModel): modified_at: Optional[datetime] = None digest: Optional[str] = None size: Optional[ByteSize] = None @@ -394,7 +395,7 @@ class ShowResponse(SubscriptableBaseModel): class ProcessResponse(SubscriptableBaseModel): - class Model(BaseModel): + class Model(SubscriptableBaseModel): model: Optional[str] = None name: Optional[str] = None digest: Optional[str] = None diff --git a/tests/test_type_serialization.py b/tests/test_type_serialization.py new file mode 100644 index 0000000..f127b03 --- /dev/null +++ b/tests/test_type_serialization.py @@ -0,0 +1,15 @@ +from base64 import b64decode, b64encode + +from ollama._types import Image + + +def test_image_serialization(): + # Test bytes serialization + image_bytes = b'test image bytes' + img = Image(value=image_bytes) + assert img.model_dump() == b64encode(image_bytes).decode() + + # Test base64 string serialization + b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n' + img = Image(value=b64_str) + assert img.model_dump() == b64decode(b64_str).decode()