mirror of
https://github.com/ollama/ollama-python.git
synced 2026-04-25 08:48:18 +08:00
Managing multiple type options
This commit is contained in:
parent
438360339c
commit
afe7db65be
@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
@ -229,7 +230,7 @@ class Tool(SubscriptableBaseModel):
|
||||
description: str
|
||||
|
||||
class Parameters(SubscriptableBaseModel):
|
||||
type: str
|
||||
type: Union[str, List[str]]
|
||||
required: Optional[Sequence[str]] = None
|
||||
properties: Optional[JsonSchemaValue] = None
|
||||
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, get_args, get_origin
|
||||
from ollama._types import Tool
|
||||
from icecream import ic
|
||||
|
||||
|
||||
PYTHON_TO_JSON_TYPES = {
|
||||
str: 'string',
|
||||
@ -16,18 +14,27 @@ PYTHON_TO_JSON_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
def _get_json_type(python_type: Any) -> str:
|
||||
def _get_json_type(python_type: Any) -> str | list[str]:
|
||||
"""Convert Python type to JSON schema type."""
|
||||
# Handle Optional types (Union[type, None])
|
||||
# Handle Optional types (Union[type, None] and type | None)
|
||||
origin = get_origin(python_type)
|
||||
if origin is Union:
|
||||
print('python_type', python_type)
|
||||
args = get_args(python_type)
|
||||
if len(args) == 2 and args[1] is type(None): # Optional case
|
||||
return _get_json_type(args[0])
|
||||
# Filter out None/NoneType from union args
|
||||
non_none_args = [arg for arg in args if arg not in (None, type(None))]
|
||||
non_none_args_types = [PYTHON_TO_JSON_TYPES[arg] for arg in non_none_args]
|
||||
if non_none_args:
|
||||
if len(non_none_args) == 1:
|
||||
return _get_json_type(non_none_args[0])
|
||||
else:
|
||||
return non_none_args_types
|
||||
|
||||
return 'null'
|
||||
|
||||
# Get basic type mapping
|
||||
if python_type in PYTHON_TO_JSON_TYPES:
|
||||
print('python_type', python_type, PYTHON_TO_JSON_TYPES[python_type])
|
||||
return PYTHON_TO_JSON_TYPES[python_type]
|
||||
|
||||
# Handle typing.List, typing.Dict etc.
|
||||
@ -38,6 +45,15 @@ def _get_json_type(python_type: Any) -> str:
|
||||
return 'string'
|
||||
|
||||
|
||||
def _is_optional_type(python_type: Any) -> bool:
|
||||
"""Check if a type is optional (can be None)."""
|
||||
origin = get_origin(python_type)
|
||||
if origin is Union:
|
||||
args = get_args(python_type)
|
||||
return any(arg in (None, type(None)) for arg in args)
|
||||
return False
|
||||
|
||||
|
||||
def convert_function_to_tool(func: Callable) -> Tool:
|
||||
doc_string = func.__doc__
|
||||
if not doc_string:
|
||||
@ -85,7 +101,10 @@ def convert_function_to_tool(func: Callable) -> Tool:
|
||||
'type': _get_json_type(param_type),
|
||||
'description': param_desc,
|
||||
}
|
||||
parameters['required'].append(param_name)
|
||||
|
||||
# Only add to required if not optional
|
||||
if not _is_optional_type(param_type):
|
||||
parameters['required'].append(param_name)
|
||||
|
||||
tool_dict = {
|
||||
'type': 'function',
|
||||
@ -95,7 +114,6 @@ def convert_function_to_tool(func: Callable) -> Tool:
|
||||
'parameters': parameters,
|
||||
},
|
||||
}
|
||||
print('descs', tool_dict['function']['description'])
|
||||
return Tool.model_validate(tool_dict)
|
||||
|
||||
|
||||
@ -111,7 +129,6 @@ def process_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callab
|
||||
for tool in tools:
|
||||
if callable(tool):
|
||||
processed_tools.append(convert_function_to_tool(tool))
|
||||
ic('processed tool', processed_tools[-1].model_dump())
|
||||
else:
|
||||
# Existing tool handling logic
|
||||
processed_tools.append(Tool.model_validate(tool))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user