mirror of
https://github.com/ollama/ollama-python.git
synced 2026-04-25 08:48:18 +08:00
WIP tool parsing
This commit is contained in:
parent
dc38fe4675
commit
438360339c
@ -335,6 +335,7 @@ class ModelDetails(SubscriptableBaseModel):
|
||||
|
||||
class ListResponse(SubscriptableBaseModel):
|
||||
class Model(SubscriptableBaseModel):
|
||||
name: Optional[str] = None
|
||||
modified_at: Optional[datetime] = None
|
||||
digest: Optional[str] = None
|
||||
size: Optional[ByteSize] = None
|
||||
|
||||
119
ollama/_utils.py
Normal file
119
ollama/_utils.py
Normal file
@ -0,0 +1,119 @@
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, get_args, get_origin
|
||||
from ollama._types import Tool
|
||||
from icecream import ic
|
||||
|
||||
|
||||
PYTHON_TO_JSON_TYPES = {
|
||||
str: 'string',
|
||||
int: 'integer',
|
||||
float: 'number',
|
||||
bool: 'boolean',
|
||||
list: 'array',
|
||||
dict: 'object',
|
||||
List: 'array',
|
||||
Dict: 'object',
|
||||
None: 'null',
|
||||
}
|
||||
|
||||
|
||||
def _get_json_type(python_type: Any) -> str:
|
||||
"""Convert Python type to JSON schema type."""
|
||||
# Handle Optional types (Union[type, None])
|
||||
origin = get_origin(python_type)
|
||||
if origin is Union:
|
||||
print('python_type', python_type)
|
||||
args = get_args(python_type)
|
||||
if len(args) == 2 and args[1] is type(None): # Optional case
|
||||
return _get_json_type(args[0])
|
||||
|
||||
# Get basic type mapping
|
||||
if python_type in PYTHON_TO_JSON_TYPES:
|
||||
return PYTHON_TO_JSON_TYPES[python_type]
|
||||
|
||||
# Handle typing.List, typing.Dict etc.
|
||||
if origin in PYTHON_TO_JSON_TYPES:
|
||||
return PYTHON_TO_JSON_TYPES[origin]
|
||||
|
||||
# Default to string if type is unknown
|
||||
return 'string'
|
||||
|
||||
|
||||
def convert_function_to_tool(func: Callable) -> Tool:
|
||||
doc_string = func.__doc__
|
||||
if not doc_string:
|
||||
raise ValueError(f'Function {func.__name__} must have a docstring in Google format. Example:\n' '"""Add two numbers.\n\n' 'Args:\n' ' a: First number\n' ' b: Second number\n\n' 'Returns:\n' ' int: Sum of the numbers\n' '"""')
|
||||
|
||||
# Extract description from docstring - get all lines before Args:
|
||||
description_lines = []
|
||||
for line in doc_string.split('\n'):
|
||||
line = line.strip()
|
||||
if line.startswith('Args:'):
|
||||
break
|
||||
if line:
|
||||
description_lines.append(line)
|
||||
description = ' '.join(description_lines).strip()
|
||||
|
||||
# Parse Args section
|
||||
if 'Args:' not in doc_string:
|
||||
raise ValueError(f'Function {func.__name__} docstring must have an Args section in Google format')
|
||||
|
||||
args_section = doc_string.split('Args:')[1]
|
||||
if 'Returns:' in args_section:
|
||||
args_section = args_section.split('Returns:')[0]
|
||||
|
||||
# Build parameters from function annotations
|
||||
parameters = {'type': 'object', 'properties': {}, 'required': []}
|
||||
|
||||
# Build parameters dict
|
||||
for param_name, param_type in func.__annotations__.items():
|
||||
if param_name == 'return':
|
||||
continue
|
||||
|
||||
# Find param description in Args section
|
||||
param_desc = None
|
||||
for line in args_section.split('\n'):
|
||||
line = line.strip()
|
||||
# Check for parameter name with or without colon, space, or parentheses to mitigate formatting issues
|
||||
if line.startswith(param_name + ':') or line.startswith(param_name + ' ') or line.startswith(param_name + '('):
|
||||
param_desc = line.split(':', 1)[1].strip()
|
||||
break
|
||||
|
||||
if not param_desc:
|
||||
raise ValueError(f'Parameter {param_name} must have a description in the Args section')
|
||||
|
||||
parameters['properties'][param_name] = {
|
||||
'type': _get_json_type(param_type),
|
||||
'description': param_desc,
|
||||
}
|
||||
parameters['required'].append(param_name)
|
||||
|
||||
tool_dict = {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': func.__name__,
|
||||
'description': description,
|
||||
'parameters': parameters,
|
||||
},
|
||||
}
|
||||
print('descs', tool_dict['function']['description'])
|
||||
return Tool.model_validate(tool_dict)
|
||||
|
||||
|
||||
def process_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Sequence[Tool]:
|
||||
"""
|
||||
Process a sequence of tools that can be mappings, Tool objects, or callable functions.
|
||||
Returns a sequence of validated Tool objects.
|
||||
"""
|
||||
if not tools:
|
||||
return []
|
||||
|
||||
processed_tools = []
|
||||
for tool in tools:
|
||||
if callable(tool):
|
||||
processed_tools.append(convert_function_to_tool(tool))
|
||||
ic('processed tool', processed_tools[-1].model_dump())
|
||||
else:
|
||||
# Existing tool handling logic
|
||||
processed_tools.append(Tool.model_validate(tool))
|
||||
|
||||
return processed_tools
|
||||
Loading…
Reference in New Issue
Block a user