mirror of
https://github.com/ollama/ollama-python.git
synced 2026-05-01 11:48:17 +08:00
Passing Functions as Tools (#321)
* Functions can now be passed as tools
This commit is contained in:
parent
da2893b099
commit
139c89e833
@ -10,6 +10,7 @@ from hashlib import sha256
|
|||||||
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
Callable,
|
||||||
Literal,
|
Literal,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
@ -22,6 +23,9 @@ from typing import (
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
from ollama._utils import convert_function_to_tool
|
||||||
|
|
||||||
if sys.version_info < (3, 9):
|
if sys.version_info < (3, 9):
|
||||||
from typing import Iterator, AsyncIterator
|
from typing import Iterator, AsyncIterator
|
||||||
else:
|
else:
|
||||||
@ -284,7 +288,7 @@ class Client(BaseClient):
|
|||||||
model: str = '',
|
model: str = '',
|
||||||
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
|
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
|
||||||
*,
|
*,
|
||||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
|
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Literal['', 'json']] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
@ -293,6 +297,30 @@ class Client(BaseClient):
|
|||||||
"""
|
"""
|
||||||
Create a chat response using the requested model.
|
Create a chat response using the requested model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools:
|
||||||
|
A JSON schema as a dict, an Ollama Tool or a Python Function.
|
||||||
|
Python functions need to follow Google style docstrings to be converted to an Ollama Tool.
|
||||||
|
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
|
||||||
|
stream: Whether to stream the response.
|
||||||
|
format: The format of the response.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
def add_two_numbers(a: int, b: int) -> int:
|
||||||
|
'''
|
||||||
|
Add two numbers together.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a: First number to add
|
||||||
|
b: Second number to add
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The sum of a and b
|
||||||
|
'''
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
|
||||||
|
|
||||||
Raises `RequestError` if a model is not provided.
|
Raises `RequestError` if a model is not provided.
|
||||||
|
|
||||||
Raises `ResponseError` if the request could not be fulfilled.
|
Raises `ResponseError` if the request could not be fulfilled.
|
||||||
@ -750,7 +778,7 @@ class AsyncClient(BaseClient):
|
|||||||
model: str = '',
|
model: str = '',
|
||||||
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
|
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
|
||||||
*,
|
*,
|
||||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
|
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
|
||||||
stream: Literal[True] = True,
|
stream: Literal[True] = True,
|
||||||
format: Optional[Literal['', 'json']] = None,
|
format: Optional[Literal['', 'json']] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
@ -771,6 +799,30 @@ class AsyncClient(BaseClient):
|
|||||||
"""
|
"""
|
||||||
Create a chat response using the requested model.
|
Create a chat response using the requested model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools:
|
||||||
|
A JSON schema as a dict, an Ollama Tool or a Python Function.
|
||||||
|
Python functions need to follow Google style docstrings to be converted to an Ollama Tool.
|
||||||
|
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
|
||||||
|
stream: Whether to stream the response.
|
||||||
|
format: The format of the response.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
def add_two_numbers(a: int, b: int) -> int:
|
||||||
|
'''
|
||||||
|
Add two numbers together.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a: First number to add
|
||||||
|
b: Second number to add
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The sum of a and b
|
||||||
|
'''
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
await client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
|
||||||
|
|
||||||
Raises `RequestError` if a model is not provided.
|
Raises `RequestError` if a model is not provided.
|
||||||
|
|
||||||
Raises `ResponseError` if the request could not be fulfilled.
|
Raises `ResponseError` if the request could not be fulfilled.
|
||||||
@ -1075,9 +1127,9 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]]) -> Iterator[Tool]:
|
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Iterator[Tool]:
|
||||||
for tool in tools or []:
|
for unprocessed_tool in tools or []:
|
||||||
yield Tool.model_validate(tool)
|
yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)
|
||||||
|
|
||||||
|
|
||||||
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
|
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
|
||||||
|
|||||||
@ -1,26 +1,18 @@
|
|||||||
import json
|
import json
|
||||||
from base64 import b64encode
|
from base64 import b64decode, b64encode
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import (
|
from typing import Any, Mapping, Optional, Union, Sequence
|
||||||
Any,
|
|
||||||
Literal,
|
from typing_extensions import Annotated, Literal
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ByteSize,
|
ByteSize,
|
||||||
|
ConfigDict,
|
||||||
Field,
|
Field,
|
||||||
FilePath,
|
|
||||||
Base64Str,
|
|
||||||
model_serializer,
|
model_serializer,
|
||||||
)
|
)
|
||||||
from pydantic.json_schema import JsonSchemaValue
|
|
||||||
|
|
||||||
|
|
||||||
class SubscriptableBaseModel(BaseModel):
|
class SubscriptableBaseModel(BaseModel):
|
||||||
@ -95,16 +87,26 @@ class BaseGenerateRequest(BaseStreamableRequest):
|
|||||||
|
|
||||||
|
|
||||||
class Image(BaseModel):
|
class Image(BaseModel):
|
||||||
value: Union[FilePath, Base64Str, bytes]
|
value: Union[str, bytes, Path]
|
||||||
|
|
||||||
# This overloads the `model_dump` method and returns values depending on the type of the `value` field
|
|
||||||
@model_serializer
|
@model_serializer
|
||||||
def serialize_model(self):
|
def serialize_model(self):
|
||||||
if isinstance(self.value, Path):
|
if isinstance(self.value, (Path, bytes)):
|
||||||
return b64encode(self.value.read_bytes()).decode()
|
return b64encode(self.value.read_bytes() if isinstance(self.value, Path) else self.value).decode()
|
||||||
elif isinstance(self.value, bytes):
|
|
||||||
return b64encode(self.value).decode()
|
if isinstance(self.value, str):
|
||||||
return self.value
|
if Path(self.value).exists():
|
||||||
|
return b64encode(Path(self.value).read_bytes()).decode()
|
||||||
|
|
||||||
|
if self.value.split('.')[-1] in ('png', 'jpg', 'jpeg', 'webp'):
|
||||||
|
raise ValueError(f'File {self.value} does not exist')
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try to decode to check if it's already base64
|
||||||
|
b64decode(self.value)
|
||||||
|
return self.value
|
||||||
|
except Exception:
|
||||||
|
raise ValueError('Invalid image data, expected base64 string or path to image file') from Exception
|
||||||
|
|
||||||
|
|
||||||
class GenerateRequest(BaseGenerateRequest):
|
class GenerateRequest(BaseGenerateRequest):
|
||||||
@ -222,20 +224,27 @@ class Message(SubscriptableBaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Tool(SubscriptableBaseModel):
|
class Tool(SubscriptableBaseModel):
|
||||||
type: Literal['function'] = 'function'
|
type: Optional[Literal['function']] = 'function'
|
||||||
|
|
||||||
class Function(SubscriptableBaseModel):
|
class Function(SubscriptableBaseModel):
|
||||||
name: str
|
name: Optional[str] = None
|
||||||
description: str
|
description: Optional[str] = None
|
||||||
|
|
||||||
class Parameters(SubscriptableBaseModel):
|
class Parameters(SubscriptableBaseModel):
|
||||||
type: str
|
type: Optional[Literal['object']] = 'object'
|
||||||
required: Optional[Sequence[str]] = None
|
required: Optional[Sequence[str]] = None
|
||||||
properties: Optional[JsonSchemaValue] = None
|
|
||||||
|
|
||||||
parameters: Parameters
|
class Property(SubscriptableBaseModel):
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
function: Function
|
type: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
properties: Optional[Mapping[str, Property]] = None
|
||||||
|
|
||||||
|
parameters: Optional[Parameters] = None
|
||||||
|
|
||||||
|
function: Optional[Function] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseGenerateRequest):
|
class ChatRequest(BaseGenerateRequest):
|
||||||
@ -335,6 +344,7 @@ class ModelDetails(SubscriptableBaseModel):
|
|||||||
|
|
||||||
class ListResponse(SubscriptableBaseModel):
|
class ListResponse(SubscriptableBaseModel):
|
||||||
class Model(SubscriptableBaseModel):
|
class Model(SubscriptableBaseModel):
|
||||||
|
model: Optional[str] = None
|
||||||
modified_at: Optional[datetime] = None
|
modified_at: Optional[datetime] = None
|
||||||
digest: Optional[str] = None
|
digest: Optional[str] = None
|
||||||
size: Optional[ByteSize] = None
|
size: Optional[ByteSize] = None
|
||||||
|
|||||||
87
ollama/_utils.py
Normal file
87
ollama/_utils.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from collections import defaultdict
|
||||||
|
import inspect
|
||||||
|
from typing import Callable, Union
|
||||||
|
import re
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
from ollama._types import Tool
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
|
||||||
|
parsed_docstring = defaultdict(str)
|
||||||
|
if not doc_string:
|
||||||
|
return parsed_docstring
|
||||||
|
|
||||||
|
key = hash(doc_string)
|
||||||
|
for line in doc_string.splitlines():
|
||||||
|
lowered_line = line.lower().strip()
|
||||||
|
if lowered_line.startswith('args:'):
|
||||||
|
key = 'args'
|
||||||
|
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'):
|
||||||
|
key = '_'
|
||||||
|
|
||||||
|
else:
|
||||||
|
# maybe change to a list and join later
|
||||||
|
parsed_docstring[key] += f'{line.strip()}\n'
|
||||||
|
|
||||||
|
last_key = None
|
||||||
|
for line in parsed_docstring['args'].splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if ':' in line:
|
||||||
|
# Split the line on either:
|
||||||
|
# 1. A parenthetical expression like (integer) - captured in group 1
|
||||||
|
# 2. A colon :
|
||||||
|
# Followed by optional whitespace. Only split on first occurrence.
|
||||||
|
parts = re.split(r'(?:\(([^)]*)\)|:)\s*', line, maxsplit=1)
|
||||||
|
|
||||||
|
arg_name = parts[0].strip()
|
||||||
|
last_key = arg_name
|
||||||
|
|
||||||
|
# Get the description - will be in parts[1] if parenthetical or parts[-1] if after colon
|
||||||
|
arg_description = parts[-1].strip()
|
||||||
|
if len(parts) > 2 and parts[1]: # Has parenthetical content
|
||||||
|
arg_description = parts[-1].split(':', 1)[-1].strip()
|
||||||
|
|
||||||
|
parsed_docstring[last_key] = arg_description
|
||||||
|
|
||||||
|
elif last_key and line:
|
||||||
|
parsed_docstring[last_key] += ' ' + line
|
||||||
|
|
||||||
|
return parsed_docstring
|
||||||
|
|
||||||
|
|
||||||
|
def convert_function_to_tool(func: Callable) -> Tool:
|
||||||
|
doc_string_hash = hash(inspect.getdoc(func))
|
||||||
|
parsed_docstring = _parse_docstring(inspect.getdoc(func))
|
||||||
|
schema = type(
|
||||||
|
func.__name__,
|
||||||
|
(pydantic.BaseModel,),
|
||||||
|
{
|
||||||
|
'__annotations__': {k: v.annotation if v.annotation != inspect._empty else str for k, v in inspect.signature(func).parameters.items()},
|
||||||
|
'__signature__': inspect.signature(func),
|
||||||
|
'__doc__': parsed_docstring[doc_string_hash],
|
||||||
|
},
|
||||||
|
).model_json_schema()
|
||||||
|
|
||||||
|
for k, v in schema.get('properties', {}).items():
|
||||||
|
# If type is missing, the default is string
|
||||||
|
types = {t.get('type', 'string') for t in v.get('anyOf')} if 'anyOf' in v else {v.get('type', 'string')}
|
||||||
|
if 'null' in types:
|
||||||
|
schema['required'].remove(k)
|
||||||
|
types.discard('null')
|
||||||
|
|
||||||
|
schema['properties'][k] = {
|
||||||
|
'description': parsed_docstring[k],
|
||||||
|
'type': ', '.join(types),
|
||||||
|
}
|
||||||
|
|
||||||
|
tool = Tool(
|
||||||
|
function=Tool.Function(
|
||||||
|
name=func.__name__,
|
||||||
|
description=schema.get('description', ''),
|
||||||
|
parameters=Tool.Function.Parameters(**schema),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return Tool.model_validate(tool)
|
||||||
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
from pydantic import ValidationError
|
||||||
import pytest
|
import pytest
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -8,7 +9,7 @@ from pytest_httpserver import HTTPServer, URIPattern
|
|||||||
from werkzeug.wrappers import Request, Response
|
from werkzeug.wrappers import Request, Response
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ollama._client import Client, AsyncClient
|
from ollama._client import Client, AsyncClient, _copy_tools
|
||||||
|
|
||||||
|
|
||||||
class PrefixPattern(URIPattern):
|
class PrefixPattern(URIPattern):
|
||||||
@ -982,3 +983,56 @@ def test_headers():
|
|||||||
)
|
)
|
||||||
assert client._client.headers['x-custom'] == 'value'
|
assert client._client.headers['x-custom'] == 'value'
|
||||||
assert client._client.headers['content-type'] == 'application/json'
|
assert client._client.headers['content-type'] == 'application/json'
|
||||||
|
|
||||||
|
|
||||||
|
def test_copy_tools():
|
||||||
|
def func1(x: int) -> str:
|
||||||
|
"""Simple function 1.
|
||||||
|
Args:
|
||||||
|
x (integer): A number
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def func2(y: str) -> int:
|
||||||
|
"""Simple function 2.
|
||||||
|
Args:
|
||||||
|
y (string): A string
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Test with list of functions
|
||||||
|
tools = list(_copy_tools([func1, func2]))
|
||||||
|
assert len(tools) == 2
|
||||||
|
assert tools[0].function.name == 'func1'
|
||||||
|
assert tools[1].function.name == 'func2'
|
||||||
|
|
||||||
|
# Test with empty input
|
||||||
|
assert list(_copy_tools()) == []
|
||||||
|
assert list(_copy_tools(None)) == []
|
||||||
|
assert list(_copy_tools([])) == []
|
||||||
|
|
||||||
|
# Test with mix of functions and tool dicts
|
||||||
|
tool_dict = {
|
||||||
|
'type': 'function',
|
||||||
|
'function': {
|
||||||
|
'name': 'test',
|
||||||
|
'description': 'Test function',
|
||||||
|
'parameters': {
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {'x': {'type': 'string', 'description': 'A string'}},
|
||||||
|
'required': ['x'],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tools = list(_copy_tools([func1, tool_dict]))
|
||||||
|
assert len(tools) == 2
|
||||||
|
assert tools[0].function.name == 'func1'
|
||||||
|
assert tools[1].function.name == 'test'
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_validation():
|
||||||
|
# Raises ValidationError when used as it is a generator
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
|
||||||
|
list(_copy_tools([invalid_tool]))
|
||||||
|
|||||||
@ -1,15 +1,48 @@
|
|||||||
from base64 import b64decode, b64encode
|
from base64 import b64encode
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
from ollama._types import Image
|
from ollama._types import Image
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
|
||||||
def test_image_serialization():
|
def test_image_serialization_bytes():
|
||||||
# Test bytes serialization
|
|
||||||
image_bytes = b'test image bytes'
|
image_bytes = b'test image bytes'
|
||||||
|
encoded_string = b64encode(image_bytes).decode()
|
||||||
img = Image(value=image_bytes)
|
img = Image(value=image_bytes)
|
||||||
assert img.model_dump() == b64encode(image_bytes).decode()
|
assert img.model_dump() == encoded_string
|
||||||
|
|
||||||
# Test base64 string serialization
|
|
||||||
|
def test_image_serialization_base64_string():
|
||||||
b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n'
|
b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n'
|
||||||
img = Image(value=b64_str)
|
img = Image(value=b64_str)
|
||||||
assert img.model_dump() == b64decode(b64_str).decode()
|
assert img.model_dump() == b64_str # Should return as-is if valid base64
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_serialization_plain_string():
|
||||||
|
img = Image(value='not a path or base64')
|
||||||
|
assert img.model_dump() == 'not a path or base64' # Should return as-is
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_serialization_path():
|
||||||
|
with tempfile.NamedTemporaryFile() as temp_file:
|
||||||
|
temp_file.write(b'test file content')
|
||||||
|
temp_file.flush()
|
||||||
|
img = Image(value=Path(temp_file.name))
|
||||||
|
assert img.model_dump() == b64encode(b'test file content').decode()
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_serialization_string_path():
|
||||||
|
with tempfile.NamedTemporaryFile() as temp_file:
|
||||||
|
temp_file.write(b'test file content')
|
||||||
|
temp_file.flush()
|
||||||
|
img = Image(value=temp_file.name)
|
||||||
|
assert img.model_dump() == b64encode(b'test file content').decode()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
img = Image(value='some_path/that/does/not/exist.png')
|
||||||
|
img.model_dump()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
img = Image(value='not an image')
|
||||||
|
img.model_dump()
|
||||||
|
|||||||
270
tests/test_utils.py
Normal file
270
tests/test_utils.py
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from typing import Dict, List, Mapping, Sequence, Set, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
from ollama._utils import convert_function_to_tool
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_to_tool_conversion():
|
||||||
|
def add_numbers(x: int, y: Union[int, None] = None) -> int:
|
||||||
|
"""Add two numbers together.
|
||||||
|
args:
|
||||||
|
x (integer): The first number
|
||||||
|
y (integer, optional): The second number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
integer: The sum of x and y
|
||||||
|
"""
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(add_numbers).model_dump()
|
||||||
|
|
||||||
|
assert tool['type'] == 'function'
|
||||||
|
assert tool['function']['name'] == 'add_numbers'
|
||||||
|
assert tool['function']['description'] == 'Add two numbers together.'
|
||||||
|
assert tool['function']['parameters']['type'] == 'object'
|
||||||
|
assert tool['function']['parameters']['properties']['x']['type'] == 'integer'
|
||||||
|
assert tool['function']['parameters']['properties']['x']['description'] == 'The first number'
|
||||||
|
assert tool['function']['parameters']['required'] == ['x']
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_with_no_args():
|
||||||
|
def simple_func():
|
||||||
|
"""
|
||||||
|
A simple function with no arguments.
|
||||||
|
Args:
|
||||||
|
None
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(simple_func).model_dump()
|
||||||
|
assert tool['function']['name'] == 'simple_func'
|
||||||
|
assert tool['function']['description'] == 'A simple function with no arguments.'
|
||||||
|
assert tool['function']['parameters']['properties'] == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_with_all_types():
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
|
||||||
|
def all_types(
|
||||||
|
x: int,
|
||||||
|
y: str,
|
||||||
|
z: list[int],
|
||||||
|
w: dict[str, int],
|
||||||
|
v: int | str | None,
|
||||||
|
) -> int | dict[str, int] | str | list[int] | None:
|
||||||
|
"""
|
||||||
|
A function with all types.
|
||||||
|
Args:
|
||||||
|
x (integer): The first number
|
||||||
|
y (string): The second number
|
||||||
|
z (array): The third number
|
||||||
|
w (object): The fourth number
|
||||||
|
v (integer | string | None): The fifth number
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
|
||||||
|
def all_types(
|
||||||
|
x: int,
|
||||||
|
y: str,
|
||||||
|
z: Sequence,
|
||||||
|
w: Mapping[str, int],
|
||||||
|
d: Dict[str, int],
|
||||||
|
s: Set[int],
|
||||||
|
t: Tuple[int, str],
|
||||||
|
l: List[int], # noqa: E741
|
||||||
|
o: Union[int, None],
|
||||||
|
) -> Union[Mapping[str, int], str, None]:
|
||||||
|
"""
|
||||||
|
A function with all types.
|
||||||
|
Args:
|
||||||
|
x (integer): The first number
|
||||||
|
y (string): The second number
|
||||||
|
z (array): The third number
|
||||||
|
w (object): The fourth number
|
||||||
|
d (object): The fifth number
|
||||||
|
s (array): The sixth number
|
||||||
|
t (array): The seventh number
|
||||||
|
l (array): The eighth number
|
||||||
|
o (integer | None): The ninth number
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool_json = convert_function_to_tool(all_types).model_dump_json()
|
||||||
|
tool = json.loads(tool_json)
|
||||||
|
assert tool['function']['parameters']['properties']['x']['type'] == 'integer'
|
||||||
|
assert tool['function']['parameters']['properties']['y']['type'] == 'string'
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
assert tool['function']['parameters']['properties']['z']['type'] == 'array'
|
||||||
|
assert tool['function']['parameters']['properties']['w']['type'] == 'object'
|
||||||
|
assert set(x.strip().strip("'") for x in tool['function']['parameters']['properties']['v']['type'].removeprefix('[').removesuffix(']').split(',')) == {'string', 'integer'}
|
||||||
|
assert tool['function']['parameters']['properties']['v']['type'] != 'null'
|
||||||
|
assert tool['function']['parameters']['required'] == ['x', 'y', 'z', 'w']
|
||||||
|
else:
|
||||||
|
assert tool['function']['parameters']['properties']['z']['type'] == 'array'
|
||||||
|
assert tool['function']['parameters']['properties']['w']['type'] == 'object'
|
||||||
|
assert tool['function']['parameters']['properties']['d']['type'] == 'object'
|
||||||
|
assert tool['function']['parameters']['properties']['s']['type'] == 'array'
|
||||||
|
assert tool['function']['parameters']['properties']['t']['type'] == 'array'
|
||||||
|
assert tool['function']['parameters']['properties']['l']['type'] == 'array'
|
||||||
|
assert tool['function']['parameters']['properties']['o']['type'] == 'integer'
|
||||||
|
assert tool['function']['parameters']['properties']['o']['type'] != 'null'
|
||||||
|
assert tool['function']['parameters']['required'] == ['x', 'y', 'z', 'w', 'd', 's', 't', 'l']
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_docstring_parsing():
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
def func_with_complex_docs(x: int, y: List[str]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Test function with complex docstring.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (integer): A number
|
||||||
|
with multiple lines
|
||||||
|
y (array of string): A list
|
||||||
|
with multiple lines
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: A dictionary
|
||||||
|
with multiple lines
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(func_with_complex_docs).model_dump()
|
||||||
|
assert tool['function']['description'] == 'Test function with complex docstring.'
|
||||||
|
assert tool['function']['parameters']['properties']['x']['description'] == 'A number with multiple lines'
|
||||||
|
assert tool['function']['parameters']['properties']['y']['description'] == 'A list with multiple lines'
|
||||||
|
|
||||||
|
|
||||||
|
def test_skewed_docstring_parsing():
|
||||||
|
def add_two_numbers(x: int, y: int) -> int:
|
||||||
|
"""
|
||||||
|
Add two numbers together.
|
||||||
|
Args:
|
||||||
|
x (integer): : The first number
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
y (integer ): The second number
|
||||||
|
Returns:
|
||||||
|
integer: The sum of x and y
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(add_two_numbers).model_dump()
|
||||||
|
assert tool['function']['parameters']['properties']['x']['description'] == ': The first number'
|
||||||
|
assert tool['function']['parameters']['properties']['y']['description'] == 'The second number'
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_with_no_docstring():
|
||||||
|
def no_docstring():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def no_docstring_with_args(x: int, y: int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(no_docstring).model_dump()
|
||||||
|
assert tool['function']['description'] == ''
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(no_docstring_with_args).model_dump()
|
||||||
|
assert tool['function']['description'] == ''
|
||||||
|
assert tool['function']['parameters']['properties']['x']['description'] == ''
|
||||||
|
assert tool['function']['parameters']['properties']['y']['description'] == ''
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_with_only_description():
|
||||||
|
def only_description():
|
||||||
|
"""
|
||||||
|
A function with only a description.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(only_description).model_dump()
|
||||||
|
assert tool['function']['description'] == 'A function with only a description.'
|
||||||
|
assert tool['function']['parameters'] == {'type': 'object', 'properties': {}, 'required': None}
|
||||||
|
|
||||||
|
def only_description_with_args(x: int, y: int):
|
||||||
|
"""
|
||||||
|
A function with only a description.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(only_description_with_args).model_dump()
|
||||||
|
assert tool['function']['description'] == 'A function with only a description.'
|
||||||
|
assert tool['function']['parameters'] == {
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {
|
||||||
|
'x': {'type': 'integer', 'description': ''},
|
||||||
|
'y': {'type': 'integer', 'description': ''},
|
||||||
|
},
|
||||||
|
'required': ['x', 'y'],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_with_yields():
|
||||||
|
def function_with_yields(x: int, y: int):
|
||||||
|
"""
|
||||||
|
A function with yields section.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: the first number
|
||||||
|
y: the second number
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The sum of x and y
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(function_with_yields).model_dump()
|
||||||
|
assert tool['function']['description'] == 'A function with yields section.'
|
||||||
|
assert tool['function']['parameters']['properties']['x']['description'] == 'the first number'
|
||||||
|
assert tool['function']['parameters']['properties']['y']['description'] == 'the second number'
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_with_no_types():
|
||||||
|
def no_types(a, b):
|
||||||
|
"""
|
||||||
|
A function with no types.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(no_types).model_dump()
|
||||||
|
assert tool['function']['parameters']['properties']['a']['type'] == 'string'
|
||||||
|
assert tool['function']['parameters']['properties']['b']['type'] == 'string'
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_with_parentheses():
|
||||||
|
def func_with_parentheses(a: int, b: int) -> int:
|
||||||
|
"""
|
||||||
|
A function with parentheses.
|
||||||
|
Args:
|
||||||
|
a: First (:thing) number to add
|
||||||
|
b: Second number to add
|
||||||
|
Returns:
|
||||||
|
int: The sum of a and b
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def func_with_parentheses_and_args(a: int, b: int):
|
||||||
|
"""
|
||||||
|
A function with parentheses and args.
|
||||||
|
Args:
|
||||||
|
a(integer) : First (:thing) number to add
|
||||||
|
b(integer) :Second number to add
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(func_with_parentheses).model_dump()
|
||||||
|
assert tool['function']['parameters']['properties']['a']['description'] == 'First (:thing) number to add'
|
||||||
|
assert tool['function']['parameters']['properties']['b']['description'] == 'Second number to add'
|
||||||
|
|
||||||
|
tool = convert_function_to_tool(func_with_parentheses_and_args).model_dump()
|
||||||
|
assert tool['function']['parameters']['properties']['a']['description'] == 'First (:thing) number to add'
|
||||||
|
assert tool['function']['parameters']['properties']['b']['description'] == 'Second number to add'
|
||||||
Loading…
Reference in New Issue
Block a user