types: relax type for tools (#550)
Some checks failed
test / test (push) Has been cancelled
test / lint (push) Has been cancelled

This commit is contained in:
Parth Sareen 2025-08-05 15:59:56 -07:00 committed by GitHub
parent dad9e1ca3a
commit 34e98bd237
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 8 deletions

View File

@ -79,7 +79,7 @@ class SubscriptableBaseModel(BaseModel):
if key in self.model_fields_set:
return True
if value := self.model_fields.get(key):
if value := self.__class__.model_fields.get(key):
return value.default is not None
return False
@ -313,7 +313,7 @@ class Message(SubscriptableBaseModel):
class Tool(SubscriptableBaseModel):
type: Optional[Literal['function']] = 'function'
type: Optional[str] = 'function'
class Function(SubscriptableBaseModel):
name: Optional[str] = None

View File

@ -79,11 +79,12 @@ def convert_function_to_tool(func: Callable) -> Tool:
}
tool = Tool(
type='function',
function=Tool.Function(
name=func.__name__,
description=schema.get('description', ''),
parameters=Tool.Function.Parameters(**schema),
)
),
)
return Tool.model_validate(tool)

View File

@ -8,7 +8,7 @@ from typing import Any
import pytest
from httpx import Response as httpxResponse
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Request, Response
@ -1136,10 +1136,11 @@ def test_copy_tools():
def test_tool_validation():
# Raises ValidationError when used as it is a generator
with pytest.raises(ValidationError):
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
list(_copy_tools([invalid_tool]))
arbitrary_tool = {'type': 'custom_type', 'function': {'name': 'test'}}
tools = list(_copy_tools([arbitrary_tool]))
assert len(tools) == 1
assert tools[0].type == 'custom_type'
assert tools[0].function.name == 'test'
def test_client_connection_error():