mirror of
https://github.com/ollama/ollama-python.git
synced 2026-01-13 21:57:16 +08:00
types: relax type for tools (#550)
This commit is contained in:
parent
dad9e1ca3a
commit
34e98bd237
@ -79,7 +79,7 @@ class SubscriptableBaseModel(BaseModel):
|
||||
if key in self.model_fields_set:
|
||||
return True
|
||||
|
||||
if value := self.model_fields.get(key):
|
||||
if value := self.__class__.model_fields.get(key):
|
||||
return value.default is not None
|
||||
|
||||
return False
|
||||
@ -313,7 +313,7 @@ class Message(SubscriptableBaseModel):
|
||||
|
||||
|
||||
class Tool(SubscriptableBaseModel):
|
||||
type: Optional[Literal['function']] = 'function'
|
||||
type: Optional[str] = 'function'
|
||||
|
||||
class Function(SubscriptableBaseModel):
|
||||
name: Optional[str] = None
|
||||
|
||||
@ -79,11 +79,12 @@ def convert_function_to_tool(func: Callable) -> Tool:
|
||||
}
|
||||
|
||||
tool = Tool(
|
||||
type='function',
|
||||
function=Tool.Function(
|
||||
name=func.__name__,
|
||||
description=schema.get('description', ''),
|
||||
parameters=Tool.Function.Parameters(**schema),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return Tool.model_validate(tool)
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
from httpx import Response as httpxResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
from pytest_httpserver import HTTPServer, URIPattern
|
||||
from werkzeug.wrappers import Request, Response
|
||||
|
||||
@ -1136,10 +1136,11 @@ def test_copy_tools():
|
||||
|
||||
|
||||
def test_tool_validation():
|
||||
# Raises ValidationError when used as it is a generator
|
||||
with pytest.raises(ValidationError):
|
||||
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
|
||||
list(_copy_tools([invalid_tool]))
|
||||
arbitrary_tool = {'type': 'custom_type', 'function': {'name': 'test'}}
|
||||
tools = list(_copy_tools([arbitrary_tool]))
|
||||
assert len(tools) == 1
|
||||
assert tools[0].type == 'custom_type'
|
||||
assert tools[0].function.name == 'test'
|
||||
|
||||
|
||||
def test_client_connection_error():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user