Passing Functions as Tools (#321)

* Functions can now be passed as tools
This commit is contained in:
Parth Sareen 2024-11-20 15:49:50 -08:00 committed by GitHub
parent da2893b099
commit 139c89e833
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 545 additions and 39 deletions

View File

@ -10,6 +10,7 @@ from hashlib import sha256
from typing import ( from typing import (
Any, Any,
Callable,
Literal, Literal,
Mapping, Mapping,
Optional, Optional,
@ -22,6 +23,9 @@ from typing import (
import sys import sys
from ollama._utils import convert_function_to_tool
if sys.version_info < (3, 9): if sys.version_info < (3, 9):
from typing import Iterator, AsyncIterator from typing import Iterator, AsyncIterator
else: else:
@ -284,7 +288,7 @@ class Client(BaseClient):
model: str = '', model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*, *,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None, tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False, stream: bool = False,
format: Optional[Literal['', 'json']] = None, format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None,
@ -293,6 +297,30 @@ class Client(BaseClient):
""" """
Create a chat response using the requested model. Create a chat response using the requested model.
Args:
tools:
A JSON schema as a dict, an Ollama Tool or a Python Function.
Python functions need to follow Google style docstrings to be converted to an Ollama Tool.
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
stream: Whether to stream the response.
format: The format of the response.
Example:
def add_two_numbers(a: int, b: int) -> int:
'''
Add two numbers together.
Args:
a: First number to add
b: Second number to add
Returns:
int: The sum of a and b
'''
return a + b
client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
Raises `RequestError` if a model is not provided. Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled. Raises `ResponseError` if the request could not be fulfilled.
@ -750,7 +778,7 @@ class AsyncClient(BaseClient):
model: str = '', model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*, *,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None, tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True, stream: Literal[True] = True,
format: Optional[Literal['', 'json']] = None, format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None,
@ -771,6 +799,30 @@ class AsyncClient(BaseClient):
""" """
Create a chat response using the requested model. Create a chat response using the requested model.
Args:
tools:
A JSON schema as a dict, an Ollama Tool or a Python Function.
Python functions need to follow Google style docstrings to be converted to an Ollama Tool.
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
stream: Whether to stream the response.
format: The format of the response.
Example:
def add_two_numbers(a: int, b: int) -> int:
'''
Add two numbers together.
Args:
a: First number to add
b: Second number to add
Returns:
int: The sum of a and b
'''
return a + b
await client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
Raises `RequestError` if a model is not provided. Raises `RequestError` if a model is not provided.
Raises `ResponseError` if the request could not be fulfilled. Raises `ResponseError` if the request could not be fulfilled.
@ -1075,9 +1127,9 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
) )
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]]) -> Iterator[Tool]: def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Iterator[Tool]:
for tool in tools or []: for unprocessed_tool in tools or []:
yield Tool.model_validate(tool) yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]: def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:

View File

@ -1,26 +1,18 @@
import json import json
from base64 import b64encode from base64 import b64decode, b64encode
from pathlib import Path from pathlib import Path
from datetime import datetime from datetime import datetime
from typing import ( from typing import Any, Mapping, Optional, Union, Sequence
Any,
Literal, from typing_extensions import Annotated, Literal
Mapping,
Optional,
Sequence,
Union,
)
from typing_extensions import Annotated
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
ByteSize, ByteSize,
ConfigDict,
Field, Field,
FilePath,
Base64Str,
model_serializer, model_serializer,
) )
from pydantic.json_schema import JsonSchemaValue
class SubscriptableBaseModel(BaseModel): class SubscriptableBaseModel(BaseModel):
@ -95,16 +87,26 @@ class BaseGenerateRequest(BaseStreamableRequest):
class Image(BaseModel): class Image(BaseModel):
value: Union[FilePath, Base64Str, bytes] value: Union[str, bytes, Path]
# This overloads the `model_dump` method and returns values depending on the type of the `value` field
@model_serializer @model_serializer
def serialize_model(self): def serialize_model(self):
if isinstance(self.value, Path): if isinstance(self.value, (Path, bytes)):
return b64encode(self.value.read_bytes()).decode() return b64encode(self.value.read_bytes() if isinstance(self.value, Path) else self.value).decode()
elif isinstance(self.value, bytes):
return b64encode(self.value).decode() if isinstance(self.value, str):
return self.value if Path(self.value).exists():
return b64encode(Path(self.value).read_bytes()).decode()
if self.value.split('.')[-1] in ('png', 'jpg', 'jpeg', 'webp'):
raise ValueError(f'File {self.value} does not exist')
try:
# Try to decode to check if it's already base64
b64decode(self.value)
return self.value
except Exception:
raise ValueError('Invalid image data, expected base64 string or path to image file') from Exception
class GenerateRequest(BaseGenerateRequest): class GenerateRequest(BaseGenerateRequest):
@ -222,20 +224,27 @@ class Message(SubscriptableBaseModel):
class Tool(SubscriptableBaseModel): class Tool(SubscriptableBaseModel):
type: Literal['function'] = 'function' type: Optional[Literal['function']] = 'function'
class Function(SubscriptableBaseModel): class Function(SubscriptableBaseModel):
name: str name: Optional[str] = None
description: str description: Optional[str] = None
class Parameters(SubscriptableBaseModel): class Parameters(SubscriptableBaseModel):
type: str type: Optional[Literal['object']] = 'object'
required: Optional[Sequence[str]] = None required: Optional[Sequence[str]] = None
properties: Optional[JsonSchemaValue] = None
parameters: Parameters class Property(SubscriptableBaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
function: Function type: Optional[str] = None
description: Optional[str] = None
properties: Optional[Mapping[str, Property]] = None
parameters: Optional[Parameters] = None
function: Optional[Function] = None
class ChatRequest(BaseGenerateRequest): class ChatRequest(BaseGenerateRequest):
@ -335,6 +344,7 @@ class ModelDetails(SubscriptableBaseModel):
class ListResponse(SubscriptableBaseModel): class ListResponse(SubscriptableBaseModel):
class Model(SubscriptableBaseModel): class Model(SubscriptableBaseModel):
model: Optional[str] = None
modified_at: Optional[datetime] = None modified_at: Optional[datetime] = None
digest: Optional[str] = None digest: Optional[str] = None
size: Optional[ByteSize] = None size: Optional[ByteSize] = None

87
ollama/_utils.py Normal file
View File

@ -0,0 +1,87 @@
from __future__ import annotations
from collections import defaultdict
import inspect
from typing import Callable, Union
import re
import pydantic
from ollama._types import Tool
def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
parsed_docstring = defaultdict(str)
if not doc_string:
return parsed_docstring
key = hash(doc_string)
for line in doc_string.splitlines():
lowered_line = line.lower().strip()
if lowered_line.startswith('args:'):
key = 'args'
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'):
key = '_'
else:
# maybe change to a list and join later
parsed_docstring[key] += f'{line.strip()}\n'
last_key = None
for line in parsed_docstring['args'].splitlines():
line = line.strip()
if ':' in line:
# Split the line on either:
# 1. A parenthetical expression like (integer) - captured in group 1
# 2. A colon :
# Followed by optional whitespace. Only split on first occurrence.
parts = re.split(r'(?:\(([^)]*)\)|:)\s*', line, maxsplit=1)
arg_name = parts[0].strip()
last_key = arg_name
# Get the description - will be in parts[1] if parenthetical or parts[-1] if after colon
arg_description = parts[-1].strip()
if len(parts) > 2 and parts[1]: # Has parenthetical content
arg_description = parts[-1].split(':', 1)[-1].strip()
parsed_docstring[last_key] = arg_description
elif last_key and line:
parsed_docstring[last_key] += ' ' + line
return parsed_docstring
def convert_function_to_tool(func: Callable) -> Tool:
    """Build a ``Tool`` definition from a Python function.

    The parameter schema comes from a throwaway pydantic model created from the
    function's signature; the function and parameter descriptions come from
    the function's docstring as parsed by ``_parse_docstring``.

    Args:
        func: The function to describe. Parameters without an annotation are
            treated as ``str``.

    Returns:
        A ``Tool`` named after ``func.__name__``. A parameter whose schema
        allows ``null`` is removed from ``required``. A parameter accepting
        several types has them rendered as a comma-separated string in sorted
        order.
    """
    doc_string = inspect.getdoc(func)
    parsed_docstring = _parse_docstring(doc_string)
    signature = inspect.signature(func)

    schema = type(
        func.__name__,
        (pydantic.BaseModel,),
        {
            '__annotations__': {
                name: str if param.annotation is inspect.Parameter.empty else param.annotation
                for name, param in signature.parameters.items()
            },
            '__signature__': signature,
            # _parse_docstring stores the free-form description under hash(doc_string).
            '__doc__': parsed_docstring[hash(doc_string)],
        },
    ).model_json_schema()

    properties = schema.get('properties', {})
    required = schema.get('required', [])
    # Iterate over a snapshot because each entry is replaced inside the loop.
    for name, prop in list(properties.items()):
        # If type is missing, the default is string
        if 'anyOf' in prop:
            types = {option.get('type', 'string') for option in prop['anyOf']}
        else:
            types = {prop.get('type', 'string')}

        if 'null' in types:
            # A parameter that accepts null does not have to be supplied.
            if name in required:
                required.remove(name)
            types.discard('null')

        properties[name] = {
            'description': parsed_docstring[name],
            # Sorted so the rendered string does not depend on set iteration order.
            'type': ', '.join(sorted(types)),
        }

    tool = Tool(
        function=Tool.Function(
            name=func.__name__,
            description=schema.get('description', ''),
            parameters=Tool.Function.Parameters(**schema),
        )
    )

    return Tool.model_validate(tool)

View File

@ -1,6 +1,7 @@
import os import os
import io import io
import json import json
from pydantic import ValidationError
import pytest import pytest
import tempfile import tempfile
from pathlib import Path from pathlib import Path
@ -8,7 +9,7 @@ from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Request, Response from werkzeug.wrappers import Request, Response
from PIL import Image from PIL import Image
from ollama._client import Client, AsyncClient from ollama._client import Client, AsyncClient, _copy_tools
class PrefixPattern(URIPattern): class PrefixPattern(URIPattern):
@ -982,3 +983,56 @@ def test_headers():
) )
assert client._client.headers['x-custom'] == 'value' assert client._client.headers['x-custom'] == 'value'
assert client._client.headers['content-type'] == 'application/json' assert client._client.headers['content-type'] == 'application/json'
def test_copy_tools():
    """_copy_tools handles functions, empty input, and functions mixed with tool dicts."""

    def func1(x: int) -> str:
        """Simple function 1.
        Args:
        x (integer): A number
        """
        pass

    def func2(y: str) -> int:
        """Simple function 2.
        Args:
        y (string): A string
        """
        pass

    # A list of plain functions is converted one-to-one, in order.
    converted = list(_copy_tools([func1, func2]))
    assert len(converted) == 2
    first, second = converted
    assert first.function.name == 'func1'
    assert second.function.name == 'func2'

    # No argument, None and an empty list all produce nothing.
    assert list(_copy_tools()) == []
    assert list(_copy_tools(None)) == []
    assert list(_copy_tools([])) == []

    # Functions and ready-made tool dicts may be mixed in one call.
    parameters = {
        'type': 'object',
        'properties': {'x': {'type': 'string', 'description': 'A string'}},
        'required': ['x'],
    }
    tool_dict = {
        'type': 'function',
        'function': {
            'name': 'test',
            'description': 'Test function',
            'parameters': parameters,
        },
    }

    mixed = list(_copy_tools([func1, tool_dict]))
    assert len(mixed) == 2
    from_function, from_dict = mixed
    assert from_function.function.name == 'func1'
    assert from_dict.function.name == 'test'
def test_tool_validation():
    """A tool dict with an unsupported ``type`` raises ValidationError."""
    # _copy_tools is a generator, so nothing is validated until it is consumed;
    # list() forces that inside the raises block.
    with pytest.raises(ValidationError):
        bad_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
        list(_copy_tools([bad_tool]))

View File

@ -1,15 +1,48 @@
from base64 import b64decode, b64encode from base64 import b64encode
from pathlib import Path
import pytest
from ollama._types import Image from ollama._types import Image
import tempfile
def test_image_serialization(): def test_image_serialization_bytes():
# Test bytes serialization
image_bytes = b'test image bytes' image_bytes = b'test image bytes'
encoded_string = b64encode(image_bytes).decode()
img = Image(value=image_bytes) img = Image(value=image_bytes)
assert img.model_dump() == b64encode(image_bytes).decode() assert img.model_dump() == encoded_string
# Test base64 string serialization
def test_image_serialization_base64_string():
b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n' b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n'
img = Image(value=b64_str) img = Image(value=b64_str)
assert img.model_dump() == b64decode(b64_str).decode() assert img.model_dump() == b64_str # Should return as-is if valid base64
def test_image_serialization_plain_string():
    """A plain string value is dumped unchanged."""
    plain = 'not a path or base64'
    dumped = Image(value=plain).model_dump()
    assert dumped == 'not a path or base64'  # Should return as-is
def test_image_serialization_path():
    """A ``Path`` value is dumped as the base64 encoding of the file's bytes."""
    payload = b'test file content'
    with tempfile.NamedTemporaryFile() as temp_file:
        temp_file.write(payload)
        # Flush so the bytes are on disk before the file is read back by name.
        temp_file.flush()
        image = Image(value=Path(temp_file.name))
        assert image.model_dump() == b64encode(payload).decode()
def test_image_serialization_string_path():
    """A string naming an existing file is read and encoded; bad strings raise ValueError."""
    payload = b'test file content'
    with tempfile.NamedTemporaryFile() as temp_file:
        temp_file.write(payload)
        temp_file.flush()
        image = Image(value=temp_file.name)
        assert image.model_dump() == b64encode(payload).decode()

    # First a missing file with an image extension, then a string that is
    # neither a path nor base64.
    for bad_value in ('some_path/that/does/not/exist.png', 'not an image'):
        with pytest.raises(ValueError):
            image = Image(value=bad_value)
            image.model_dump()

270
tests/test_utils.py Normal file
View File

@ -0,0 +1,270 @@
import json
import sys
from typing import Dict, List, Mapping, Sequence, Set, Tuple, Union
from ollama._utils import convert_function_to_tool
def test_function_to_tool_conversion():
    """A typed, documented function converts into a full tool definition."""

    def add_numbers(x: int, y: Union[int, None] = None) -> int:
        """Add two numbers together.
        args:
        x (integer): The first number
        y (integer, optional): The second number
        Returns:
        integer: The sum of x and y
        """
        return x + y

    dumped = convert_function_to_tool(add_numbers).model_dump()
    function = dumped['function']
    parameters = function['parameters']
    x_schema = parameters['properties']['x']

    assert dumped['type'] == 'function'
    assert function['name'] == 'add_numbers'
    assert function['description'] == 'Add two numbers together.'
    assert parameters['type'] == 'object'
    assert x_schema['type'] == 'integer'
    assert x_schema['description'] == 'The first number'
    # y accepts None, so only x is required.
    assert parameters['required'] == ['x']
def test_function_with_no_args():
    """A function without parameters yields an empty ``properties`` mapping."""

    def simple_func():
        """
        A simple function with no arguments.
        Args:
        None
        Returns:
        None
        """
        pass

    function = convert_function_to_tool(simple_func).model_dump()['function']
    assert function['name'] == 'simple_func'
    assert function['description'] == 'A simple function with no arguments.'
    assert function['parameters']['properties'] == {}
def test_function_with_all_types():
    """Container and union annotations map to the expected schema types."""
    modern_syntax = sys.version_info >= (3, 10)

    # The `X | Y` union syntax only evaluates on 3.10+, so the fixture function
    # is defined per interpreter version.
    if modern_syntax:

        def all_types(
            x: int,
            y: str,
            z: list[int],
            w: dict[str, int],
            v: int | str | None,
        ) -> int | dict[str, int] | str | list[int] | None:
            """
            A function with all types.
            Args:
            x (integer): The first number
            y (string): The second number
            z (array): The third number
            w (object): The fourth number
            v (integer | string | None): The fifth number
            """
            pass

    else:

        def all_types(
            x: int,
            y: str,
            z: Sequence,
            w: Mapping[str, int],
            d: Dict[str, int],
            s: Set[int],
            t: Tuple[int, str],
            l: List[int],  # noqa: E741
            o: Union[int, None],
        ) -> Union[Mapping[str, int], str, None]:
            """
            A function with all types.
            Args:
            x (integer): The first number
            y (string): The second number
            z (array): The third number
            w (object): The fourth number
            d (object): The fifth number
            s (array): The sixth number
            t (array): The seventh number
            l (array): The eighth number
            o (integer | None): The ninth number
            """
            pass

    # Round-trip through JSON so the assertions see plain JSON types.
    tool = json.loads(convert_function_to_tool(all_types).model_dump_json())
    parameters = tool['function']['parameters']
    properties = parameters['properties']

    assert properties['x']['type'] == 'integer'
    assert properties['y']['type'] == 'string'

    if modern_syntax:
        assert properties['z']['type'] == 'array'
        assert properties['w']['type'] == 'object'
        # Compared as a set, so the order of the union members does not matter.
        v_type = properties['v']['type']
        v_members = {part.strip().strip("'") for part in v_type.removeprefix('[').removesuffix(']').split(',')}
        assert v_members == {'string', 'integer'}
        assert v_type != 'null'
        assert parameters['required'] == ['x', 'y', 'z', 'w']
    else:
        expected_types = (
            ('z', 'array'),
            ('w', 'object'),
            ('d', 'object'),
            ('s', 'array'),
            ('t', 'array'),
            ('l', 'array'),
            ('o', 'integer'),
        )
        for name, expected in expected_types:
            assert properties[name]['type'] == expected
        assert properties['o']['type'] != 'null'
        assert parameters['required'] == ['x', 'y', 'z', 'w', 'd', 's', 't', 'l']
def test_function_docstring_parsing():
    """Argument descriptions spanning several lines are joined with single spaces."""
    from typing import List, Dict, Any

    def func_with_complex_docs(x: int, y: List[str]) -> Dict[str, Any]:
        """
        Test function with complex docstring.
        Args:
        x (integer): A number
        with multiple lines
        y (array of string): A list
        with multiple lines
        Returns:
        object: A dictionary
        with multiple lines
        """
        pass

    function = convert_function_to_tool(func_with_complex_docs).model_dump()['function']
    properties = function['parameters']['properties']

    assert function['description'] == 'Test function with complex docstring.'
    assert properties['x']['description'] == 'A number with multiple lines'
    assert properties['y']['description'] == 'A list with multiple lines'
def test_skewed_docstring_parsing():
    """Irregular spacing and a doubled colon in ``Args:`` entries are parsed as asserted below."""

    def add_two_numbers(x: int, y: int) -> int:
        """
        Add two numbers together.
        Args:
        x (integer): : The first number
        y (integer ): The second number
        Returns:
        integer: The sum of x and y
        """
        pass

    dumped = convert_function_to_tool(add_two_numbers).model_dump()
    properties = dumped['function']['parameters']['properties']

    # Only the first colon after the parenthetical is consumed.
    assert properties['x']['description'] == ': The first number'
    assert properties['y']['description'] == 'The second number'
def test_function_with_no_docstring():
    """Functions without a docstring get empty descriptions."""

    def no_docstring():
        pass

    def no_docstring_with_args(x: int, y: int):
        pass

    bare = convert_function_to_tool(no_docstring).model_dump()
    assert bare['function']['description'] == ''

    with_args = convert_function_to_tool(no_docstring_with_args).model_dump()
    assert with_args['function']['description'] == ''
    properties = with_args['function']['parameters']['properties']
    assert properties['x']['description'] == ''
    assert properties['y']['description'] == ''
def test_function_with_only_description():
    """A docstring with no ``Args:`` section gives a description and undocumented parameters."""

    def only_description():
        """
        A function with only a description.
        """
        pass

    expected_description = 'A function with only a description.'

    without_args = convert_function_to_tool(only_description).model_dump()['function']
    assert without_args['description'] == expected_description
    assert without_args['parameters'] == {'type': 'object', 'properties': {}, 'required': None}

    def only_description_with_args(x: int, y: int):
        """
        A function with only a description.
        """
        pass

    with_args = convert_function_to_tool(only_description_with_args).model_dump()['function']
    assert with_args['description'] == expected_description
    expected_parameters = {
        'type': 'object',
        'properties': {
            'x': {'type': 'integer', 'description': ''},
            'y': {'type': 'integer', 'description': ''},
        },
        'required': ['x', 'y'],
    }
    assert with_args['parameters'] == expected_parameters
def test_function_with_yields():
    """A ``Yields:`` section does not leak into the description or the arguments."""

    def function_with_yields(x: int, y: int):
        """
        A function with yields section.
        Args:
        x: the first number
        y: the second number
        Yields:
        The sum of x and y
        """
        pass

    function = convert_function_to_tool(function_with_yields).model_dump()['function']
    properties = function['parameters']['properties']

    assert function['description'] == 'A function with yields section.'
    assert properties['x']['description'] == 'the first number'
    assert properties['y']['description'] == 'the second number'
def test_function_with_no_types():
    """Parameters without annotations are reported with type ``string``."""

    def no_types(a, b):
        """
        A function with no types.
        """
        pass

    dumped = convert_function_to_tool(no_types).model_dump()
    properties = dumped['function']['parameters']['properties']
    assert properties['a']['type'] == 'string'
    assert properties['b']['type'] == 'string'
def test_function_with_parentheses():
    """Parentheses inside a description are kept, with or without a type parenthetical."""

    def func_with_parentheses(a: int, b: int) -> int:
        """
        A function with parentheses.
        Args:
        a: First (:thing) number to add
        b: Second number to add
        Returns:
        int: The sum of a and b
        """
        pass

    def func_with_parentheses_and_args(a: int, b: int):
        """
        A function with parentheses and args.
        Args:
        a(integer) : First (:thing) number to add
        b(integer) :Second number to add
        """
        pass

    # Both docstring layouts must produce the same descriptions.
    plain = convert_function_to_tool(func_with_parentheses).model_dump()
    plain_properties = plain['function']['parameters']['properties']
    assert plain_properties['a']['description'] == 'First (:thing) number to add'
    assert plain_properties['b']['description'] == 'Second number to add'

    typed = convert_function_to_tool(func_with_parentheses_and_args).model_dump()
    typed_properties = typed['function']['parameters']['properties']
    assert typed_properties['a']['description'] == 'First (:thing) number to add'
    assert typed_properties['b']['description'] == 'Second number to add'