From 34e98bd237db8ccb0a9515515e3e83efd2528270 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Tue, 5 Aug 2025 15:59:56 -0700 Subject: [PATCH] types: relax type for tools (#550) --- ollama/_types.py | 4 ++-- ollama/_utils.py | 3 ++- tests/test_client.py | 11 ++++++----- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/ollama/_types.py b/ollama/_types.py index caf1e70..db928e5 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -79,7 +79,7 @@ class SubscriptableBaseModel(BaseModel): if key in self.model_fields_set: return True - if value := self.model_fields.get(key): + if value := self.__class__.model_fields.get(key): return value.default is not None return False @@ -313,7 +313,7 @@ class Message(SubscriptableBaseModel): class Tool(SubscriptableBaseModel): - type: Optional[Literal['function']] = 'function' + type: Optional[str] = 'function' class Function(SubscriptableBaseModel): name: Optional[str] = None diff --git a/ollama/_utils.py b/ollama/_utils.py index 653a04c..15f1cc0 100644 --- a/ollama/_utils.py +++ b/ollama/_utils.py @@ -79,11 +79,12 @@ def convert_function_to_tool(func: Callable) -> Tool: } tool = Tool( + type='function', function=Tool.Function( name=func.__name__, description=schema.get('description', ''), parameters=Tool.Function.Parameters(**schema), - ) + ), ) return Tool.model_validate(tool) diff --git a/tests/test_client.py b/tests/test_client.py index 1e66184..6917edc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,7 @@ from typing import Any import pytest from httpx import Response as httpxResponse -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from pytest_httpserver import HTTPServer, URIPattern from werkzeug.wrappers import Request, Response @@ -1136,10 +1136,11 @@ def test_copy_tools(): def test_tool_validation(): - # Raises ValidationError when used as it is a generator - with pytest.raises(ValidationError): - invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}} - list(_copy_tools([invalid_tool])) + arbitrary_tool = {'type': 'custom_type', 'function': {'name': 'test'}} + tools = list(_copy_tools([arbitrary_tool])) + assert len(tools) == 1 + assert tools[0].type == 'custom_type' + assert tools[0].function.name == 'test' def test_client_connection_error():