diff --git a/ollama/_types.py b/ollama/_types.py index 653e122..3b2452e 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -4,6 +4,7 @@ from pathlib import Path from datetime import datetime from typing import ( Any, + List, Literal, Mapping, Optional, @@ -229,7 +230,7 @@ class Tool(SubscriptableBaseModel): description: str class Parameters(SubscriptableBaseModel): - type: str + type: Union[str, List[str]] required: Optional[Sequence[str]] = None properties: Optional[JsonSchemaValue] = None diff --git a/ollama/_utils.py b/ollama/_utils.py index f4d14a0..b3de43f 100644 --- a/ollama/_utils.py +++ b/ollama/_utils.py @@ -1,7 +1,5 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, get_args, get_origin from ollama._types import Tool -from icecream import ic - PYTHON_TO_JSON_TYPES = { str: 'string', @@ -16,18 +14,27 @@ PYTHON_TO_JSON_TYPES = { } -def _get_json_type(python_type: Any) -> str: +def _get_json_type(python_type: Any) -> str | list[str]: """Convert Python type to JSON schema type.""" - # Handle Optional types (Union[type, None]) + # Handle Optional types (Union[type, None] and type | None) origin = get_origin(python_type) if origin is Union: print('python_type', python_type) args = get_args(python_type) - if len(args) == 2 and args[1] is type(None): # Optional case - return _get_json_type(args[0]) + # Filter out None/NoneType from union args + non_none_args = [arg for arg in args if arg not in (None, type(None))] + non_none_args_types = [PYTHON_TO_JSON_TYPES[arg] for arg in non_none_args] + if non_none_args: + if len(non_none_args) == 1: + return _get_json_type(non_none_args[0]) + else: + return non_none_args_types + + return 'null' # Get basic type mapping if python_type in PYTHON_TO_JSON_TYPES: + print('python_type', python_type, PYTHON_TO_JSON_TYPES[python_type]) return PYTHON_TO_JSON_TYPES[python_type] # Handle typing.List, typing.Dict etc. @@ -38,6 +45,15 @@ def _get_json_type(python_type: Any) -> str: return 'string' +def _is_optional_type(python_type: Any) -> bool: + """Check if a type is optional (can be None).""" + origin = get_origin(python_type) + if origin is Union: + args = get_args(python_type) + return any(arg in (None, type(None)) for arg in args) + return False + + def convert_function_to_tool(func: Callable) -> Tool: doc_string = func.__doc__ if not doc_string: @@ -85,7 +101,10 @@ def convert_function_to_tool(func: Callable) -> Tool: 'type': _get_json_type(param_type), 'description': param_desc, } - parameters['required'].append(param_name) + + # Only add to required if not optional + if not _is_optional_type(param_type): + parameters['required'].append(param_name) tool_dict = { 'type': 'function', @@ -95,7 +114,6 @@ def convert_function_to_tool(func: Callable) -> Tool: 'parameters': parameters, }, } - print('descs', tool_dict['function']['description']) return Tool.model_validate(tool_dict) @@ -111,7 +129,6 @@ def process_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callab for tool in tools: if callable(tool): processed_tools.append(convert_function_to_tool(tool)) - ic('processed tool', processed_tools[-1].model_dump()) else: # Existing tool handling logic processed_tools.append(Tool.model_validate(tool))