Added SubscriptableBaseModel to the Model classes and added codec tests

This commit is contained in:
Parth Sareen 2024-11-06 12:00:35 -08:00
parent 5f4496e22c
commit a69c07db8c
2 changed files with 274 additions and 2 deletions

View File

@ -97,6 +97,7 @@ class BaseGenerateRequest(BaseStreamableRequest):
class Image(BaseModel):
value: Union[FilePath, Base64Str, bytes]
# This overloads the `model_dump` method and returns values depending on the type of the `value` field
@model_serializer
def serialize_model(self):
if isinstance(self.value, Path):
@ -333,7 +334,7 @@ class ModelDetails(SubscriptableBaseModel):
class ListResponse(SubscriptableBaseModel):
class Model(BaseModel):
class Model(SubscriptableBaseModel):
modified_at: Optional[datetime] = None
digest: Optional[str] = None
size: Optional[ByteSize] = None
@ -394,7 +395,7 @@ class ShowResponse(SubscriptableBaseModel):
class ProcessResponse(SubscriptableBaseModel):
class Model(BaseModel):
class Model(SubscriptableBaseModel):
model: Optional[str] = None
name: Optional[str] = None
digest: Optional[str] = None

View File

@ -0,0 +1,271 @@
import pytest
from pathlib import Path
from base64 import b64decode, b64encode
from datetime import datetime
from ollama._types import (
GenerateRequest,
ChatRequest,
Message,
Image,
Tool,
Options,
EmbedRequest,
EmbedResponse,
PullRequest,
PushRequest,
ShowRequest,
ShowResponse,
ProcessResponse,
ListResponse,
RequestError,
ResponseError
)
def test_generate_request_serialization():
req = GenerateRequest(
model="test-model",
prompt="test prompt",
system="test system",
template="test template",
context=[1,2,3],
options={"temperature": 0.7}
)
data = req.model_dump(exclude_none=True)
assert data == {
"model": "test-model",
"prompt": "test prompt",
"system": "test system",
"template": "test template",
"context": [1,2,3],
"options": {"temperature": 0.7}
}
def test_chat_request_serialization():
req = ChatRequest(
model="test-model",
messages=[
Message(role="user", content="Hello"),
Message(role="assistant", content="Hi there")
]
)
data = req.model_dump(exclude_none=True)
assert data == {
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"}
]
}
def test_image_serialization():
# Test bytes serialization
image_bytes = b"test image bytes"
img = Image(value=image_bytes)
assert img.model_dump() == b64encode(image_bytes).decode()
# Test base64 string serialization
b64_str = "dGVzdCBiYXNlNjQgc3RyaW5n"
img = Image(value=b64_str)
print(img.model_dump())
assert img.model_dump() == b64decode(b64_str).decode()
def test_tool_serialization():
tool = Tool(
type="function",
function=Tool.Function(
name="test_function",
description="A test function",
parameters=Tool.Function.Parameters(
type="object",
required=["param1"],
properties={"param1": {"type": "string"}}
)
)
)
data = tool.model_dump(exclude_none=True)
assert data == {
"type": "function",
"function": {
"name": "test_function",
"description": "A test function",
"parameters": {
"type": "object",
"required": ["param1"],
"properties": {"param1": {"type": "string"}}
}
}
}
def test_options_serialization():
options = Options(
num_ctx=4096,
num_gpu=1,
temperature=0.7,
top_p=0.9,
stop=[".", "?", "!"]
)
data = options.model_dump(exclude_none=True)
assert data == {
"num_ctx": 4096,
"num_gpu": 1,
"temperature": 0.7,
"top_p": 0.9,
"stop": [".", "?", "!"]
}
def test_embed_request_response_serialization():
# Test EmbedRequest
req = EmbedRequest(
model="test-model",
input="test input",
truncate=True
)
data = req.model_dump(exclude_none=True)
assert data == {
"model": "test-model",
"input": "test input",
"truncate": True
}
# Test EmbedResponse
resp = EmbedResponse(
embeddings=[[0.1, 0.2, 0.3]],
total_duration=1000
)
data = resp.model_dump(exclude_none=True)
assert data == {
"embeddings": [[0.1, 0.2, 0.3]],
"total_duration": 1000
}
def test_pull_push_request_serialization():
# Test PullRequest
pull_req = PullRequest(
model="test-model",
insecure=True,
stream=True
)
data = pull_req.model_dump(exclude_none=True)
assert data == {
"model": "test-model",
"insecure": True,
"stream": True
}
# Test PushRequest
push_req = PushRequest(
model="test-model",
insecure=True,
stream=True
)
data = push_req.model_dump(exclude_none=True)
assert data == {
"model": "test-model",
"insecure": True,
"stream": True
}
def test_show_request_response_serialization():
# Test ShowRequest
req = ShowRequest(model="test-model")
data = req.model_dump(exclude_none=True)
assert data == {
"model": "test-model"
}
# Test ShowResponse
resp = ShowResponse(
modified_at=datetime(2024, 1, 1),
template="test template",
modelfile="test modelfile",
license="MIT",
details={"format": "gguf", "family": "llama"},
model_info={"version": "1.0"},
parameters="test params"
)
data = resp.model_dump(exclude_none=True)
assert data == {
"modified_at": datetime(2024, 1, 1),
"template": "test template",
"modelfile": "test modelfile",
"license": "MIT",
"details": {"format": "gguf", "family": "llama"},
"modelinfo": {"version": "1.0"},
"parameters": "test params"
}
def test_process_response_serialization():
resp = ProcessResponse(
models=[
ProcessResponse.Model(
model="test-model",
name="test",
digest="sha256:abc",
size=1000,
size_vram=500
)
]
)
data = resp.model_dump(exclude_none=True)
assert data == {
"models": [{
"model": "test-model",
"name": "test",
"digest": "sha256:abc",
"size": 1000,
"size_vram": 500
}]
}
def test_list_response_serialization():
resp = ListResponse(
models=[
ListResponse.Model(
name="test-model",
modified_at=datetime(2024, 1, 1),
digest="sha256:abc",
size=1000,
details={"format": "gguf"}
)
]
)
data = resp.model_dump(exclude_none=True)
assert data == {
"models": [{
"name": "test-model",
"modified_at": datetime(2024, 1, 1),
"digest": "sha256:abc",
"size": 1000,
"details": {"format": "gguf"}
}]
}
def test_error_classes():
# Test RequestError
error = RequestError("test error")
assert error.error == "test error"
assert str(error) == "test error"
# Test ResponseError
error = ResponseError("test error", 404)
assert error.error == "test error"
assert error.status_code == 404
assert str(error) == "test error"
# Test ResponseError with JSON error
error = ResponseError('{"error": "json error"}', 500)
assert error.error == "json error"
assert error.status_code == 500