integrate tool calls (#213)

This commit is contained in:
Josh 2024-07-17 09:40:49 -07:00 committed by GitHub
parent 1a15742705
commit 359c63daa7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 155 additions and 14 deletions

3
examples/tools/README.md Normal file
View File

@ -0,0 +1,3 @@
# tools
This example demonstrates how to utilize tool calls with an asynchronous Ollama client and the chat endpoint.

87
examples/tools/main.py Normal file
View File

@ -0,0 +1,87 @@
import json
import ollama
import asyncio
# Simulates an API call to get flight times
# In a real application, this would fetch data from a live database or API
def get_flight_times(departure: str, arrival: str) -> str:
flights = {
'NYC-LAX': {'departure': '08:00 AM', 'arrival': '11:30 AM', 'duration': '5h 30m'},
'LAX-NYC': {'departure': '02:00 PM', 'arrival': '10:30 PM', 'duration': '5h 30m'},
'LHR-JFK': {'departure': '10:00 AM', 'arrival': '01:00 PM', 'duration': '8h 00m'},
'JFK-LHR': {'departure': '09:00 PM', 'arrival': '09:00 AM', 'duration': '7h 00m'},
'CDG-DXB': {'departure': '11:00 AM', 'arrival': '08:00 PM', 'duration': '6h 00m'},
'DXB-CDG': {'departure': '03:00 AM', 'arrival': '07:30 AM', 'duration': '7h 30m'},
}
key = f'{departure}-{arrival}'.upper()
return json.dumps(flights.get(key, {'error': 'Flight not found'}))
async def run(model: str):
client = ollama.AsyncClient()
# Initialize conversation with a user query
messages = [{'role': 'user', 'content': 'What is the flight time from New York (NYC) to Los Angeles (LAX)?'}]
# First API call: Send the query and function description to the model
response = await client.chat(
model=model,
messages=messages,
tools=[
{
'type': 'function',
'function': {
'name': 'get_flight_times',
'description': 'Get the flight times between two cities',
'parameters': {
'type': 'object',
'properties': {
'departure': {
'type': 'string',
'description': 'The departure city (airport code)',
},
'arrival': {
'type': 'string',
'description': 'The arrival city (airport code)',
},
},
'required': ['departure', 'arrival'],
},
},
},
],
)
# Add the model's response to the conversation history
messages.append(response['message'])
# Check if the model decided to use the provided function
if not response['message'].get('tool_calls'):
print("The model didn't use the function. Its response was:")
print(response['message']['content'])
return
# Process function calls made by the model
if response['message'].get('tool_calls'):
available_functions = {
'get_flight_times': get_flight_times,
}
for tool in response['message']['tool_calls']:
function_to_call = available_functions[tool['function']['name']]
function_response = function_to_call(tool['function']['arguments']['departure'], tool['function']['arguments']['arrival'])
# Add function response to the conversation
messages.append(
{
'role': 'tool',
'content': function_response,
}
)
# Second API call: Get final response from the model
final_response = await client.chat(model=model, messages=messages)
print(final_response['message']['content'])
# Run the async function
asyncio.run(run('mistral'))

View File

@ -27,7 +27,7 @@ try:
except metadata.PackageNotFoundError: except metadata.PackageNotFoundError:
__version__ = '0.0.0' __version__ = '0.0.0'
from ollama._types import Message, Options, RequestError, ResponseError from ollama._types import Message, Options, RequestError, ResponseError, Tool
class BaseClient: class BaseClient:
@ -180,6 +180,7 @@ class Client(BaseClient):
self, self,
model: str = '', model: str = '',
messages: Optional[Sequence[Message]] = None, messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[False] = False, stream: Literal[False] = False,
format: Literal['', 'json'] = '', format: Literal['', 'json'] = '',
options: Optional[Options] = None, options: Optional[Options] = None,
@ -191,6 +192,7 @@ class Client(BaseClient):
self, self,
model: str = '', model: str = '',
messages: Optional[Sequence[Message]] = None, messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[True] = True, stream: Literal[True] = True,
format: Literal['', 'json'] = '', format: Literal['', 'json'] = '',
options: Optional[Options] = None, options: Optional[Options] = None,
@ -201,6 +203,7 @@ class Client(BaseClient):
self, self,
model: str = '', model: str = '',
messages: Optional[Sequence[Message]] = None, messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: bool = False, stream: bool = False,
format: Literal['', 'json'] = '', format: Literal['', 'json'] = '',
options: Optional[Options] = None, options: Optional[Options] = None,
@ -222,12 +225,6 @@ class Client(BaseClient):
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages or []: for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of Message or dict-like objects')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if 'content' not in message:
raise RequestError('messages must contain content')
if images := message.get('images'): if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images] message['images'] = [_encode_image(image) for image in images]
@ -237,6 +234,7 @@ class Client(BaseClient):
json={ json={
'model': model, 'model': model,
'messages': messages, 'messages': messages,
'tools': tools or [],
'stream': stream, 'stream': stream,
'format': format, 'format': format,
'options': options or {}, 'options': options or {},
@ -574,6 +572,7 @@ class AsyncClient(BaseClient):
self, self,
model: str = '', model: str = '',
messages: Optional[Sequence[Message]] = None, messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[False] = False, stream: Literal[False] = False,
format: Literal['', 'json'] = '', format: Literal['', 'json'] = '',
options: Optional[Options] = None, options: Optional[Options] = None,
@ -585,6 +584,7 @@ class AsyncClient(BaseClient):
self, self,
model: str = '', model: str = '',
messages: Optional[Sequence[Message]] = None, messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[True] = True, stream: Literal[True] = True,
format: Literal['', 'json'] = '', format: Literal['', 'json'] = '',
options: Optional[Options] = None, options: Optional[Options] = None,
@ -595,6 +595,7 @@ class AsyncClient(BaseClient):
self, self,
model: str = '', model: str = '',
messages: Optional[Sequence[Message]] = None, messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: bool = False, stream: bool = False,
format: Literal['', 'json'] = '', format: Literal['', 'json'] = '',
options: Optional[Options] = None, options: Optional[Options] = None,
@ -615,12 +616,6 @@ class AsyncClient(BaseClient):
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages or []: for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of strings')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if 'content' not in message:
raise RequestError('messages must contain content')
if images := message.get('images'): if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images] message['images'] = [_encode_image(image) for image in images]
@ -630,6 +625,7 @@ class AsyncClient(BaseClient):
json={ json={
'model': model, 'model': model,
'messages': messages, 'messages': messages,
'tools': tools or [],
'stream': stream, 'stream': stream,
'format': format, 'format': format,
'options': options or {}, 'options': options or {},

View File

@ -1,5 +1,5 @@
import json import json
from typing import Any, TypedDict, Sequence, Literal from typing import Any, TypedDict, Sequence, Literal, Mapping
import sys import sys
@ -53,6 +53,27 @@ class GenerateResponse(BaseGenerateResponse):
'Tokenized history up to the point of the response.' 'Tokenized history up to the point of the response.'
class ToolCallFunction(TypedDict):
"""
Tool call function.
"""
name: str
'Name of the function.'
args: NotRequired[Mapping[str, Any]]
'Arguments of the function.'
class ToolCall(TypedDict):
"""
Model tool calls.
"""
function: ToolCallFunction
'Function to be called.'
class Message(TypedDict): class Message(TypedDict):
""" """
Chat message. Chat message.
@ -76,6 +97,34 @@ class Message(TypedDict):
Valid image formats depend on the model. See the model card for more information. Valid image formats depend on the model. See the model card for more information.
""" """
tool_calls: NotRequired[Sequence[ToolCall]]
"""
Tools calls to be made by the model.
"""
class Property(TypedDict):
type: str
description: str
enum: NotRequired[Sequence[str]] # `enum` is optional and can be a list of strings
class Parameters(TypedDict):
type: str
required: Sequence[str]
properties: Mapping[str, Property]
class ToolFunction(TypedDict):
name: str
description: str
parameters: Parameters
class Tool(TypedDict):
type: str
function: ToolFunction
class ChatResponse(BaseGenerateResponse): class ChatResponse(BaseGenerateResponse):
""" """

View File

@ -26,6 +26,7 @@ def test_client_chat(httpserver: HTTPServer):
json={ json={
'model': 'dummy', 'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False, 'stream': False,
'format': '', 'format': '',
'options': {}, 'options': {},
@ -73,6 +74,7 @@ def test_client_chat_stream(httpserver: HTTPServer):
json={ json={
'model': 'dummy', 'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': True, 'stream': True,
'format': '', 'format': '',
'options': {}, 'options': {},
@ -102,6 +104,7 @@ def test_client_chat_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
}, },
], ],
'tools': [],
'stream': False, 'stream': False,
'format': '', 'format': '',
'options': {}, 'options': {},
@ -522,6 +525,7 @@ async def test_async_client_chat(httpserver: HTTPServer):
json={ json={
'model': 'dummy', 'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False, 'stream': False,
'format': '', 'format': '',
'options': {}, 'options': {},
@ -560,6 +564,7 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
json={ json={
'model': 'dummy', 'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': True, 'stream': True,
'format': '', 'format': '',
'options': {}, 'options': {},
@ -590,6 +595,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
}, },
], ],
'tools': [],
'stream': False, 'stream': False,
'format': '', 'format': '',
'options': {}, 'options': {},