mirror of
https://github.com/ollama/ollama-python.git
synced 2026-05-03 12:52:35 +00:00
Passing Functions as Tools (#321)
* Functions can now be passed as tools
This commit is contained in:
+55
-1
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
from pydantic import ValidationError
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
@@ -8,7 +9,7 @@ from pytest_httpserver import HTTPServer, URIPattern
|
||||
from werkzeug.wrappers import Request, Response
|
||||
from PIL import Image
|
||||
|
||||
from ollama._client import Client, AsyncClient
|
||||
from ollama._client import Client, AsyncClient, _copy_tools
|
||||
|
||||
|
||||
class PrefixPattern(URIPattern):
|
||||
@@ -982,3 +983,56 @@ def test_headers():
|
||||
)
|
||||
assert client._client.headers['x-custom'] == 'value'
|
||||
assert client._client.headers['content-type'] == 'application/json'
|
||||
|
||||
|
||||
def test_copy_tools():
|
||||
def func1(x: int) -> str:
|
||||
"""Simple function 1.
|
||||
Args:
|
||||
x (integer): A number
|
||||
"""
|
||||
pass
|
||||
|
||||
def func2(y: str) -> int:
|
||||
"""Simple function 2.
|
||||
Args:
|
||||
y (string): A string
|
||||
"""
|
||||
pass
|
||||
|
||||
# Test with list of functions
|
||||
tools = list(_copy_tools([func1, func2]))
|
||||
assert len(tools) == 2
|
||||
assert tools[0].function.name == 'func1'
|
||||
assert tools[1].function.name == 'func2'
|
||||
|
||||
# Test with empty input
|
||||
assert list(_copy_tools()) == []
|
||||
assert list(_copy_tools(None)) == []
|
||||
assert list(_copy_tools([])) == []
|
||||
|
||||
# Test with mix of functions and tool dicts
|
||||
tool_dict = {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'test',
|
||||
'description': 'Test function',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {'x': {'type': 'string', 'description': 'A string'}},
|
||||
'required': ['x'],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tools = list(_copy_tools([func1, tool_dict]))
|
||||
assert len(tools) == 2
|
||||
assert tools[0].function.name == 'func1'
|
||||
assert tools[1].function.name == 'test'
|
||||
|
||||
|
||||
def test_tool_validation():
|
||||
# Raises ValidationError when used as it is a generator
|
||||
with pytest.raises(ValidationError):
|
||||
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
|
||||
list(_copy_tools([invalid_tool]))
|
||||
|
||||
Reference in New Issue
Block a user