Compare commits

..

1 Commits

Author SHA1 Message Date
Josh Yan 98b3a44b14 include tool calls 2024-07-16 11:39:42 -07:00
7 changed files with 23 additions and 222 deletions
+3 -3
View File
@@ -1,16 +1,16 @@
from ollama import generate
prompt = '''def remove_non_ascii(s: str) -> str:
prefix = '''def remove_non_ascii(s: str) -> str:
""" '''
suffix = """
return result
"""
response = generate(
model='codellama:7b-code',
prompt=prompt,
suffix=suffix,
prompt=f'<PRE> {prefix} <SUF>{suffix} <MID>',
options={
'num_predict': 128,
'temperature': 0,
-3
View File
@@ -1,3 +0,0 @@
# tools
This example demonstrates how to utilize tool calls with an asynchronous Ollama client and the chat endpoint.
-87
View File
@@ -1,87 +0,0 @@
import json
import ollama
import asyncio
# Simulates an API call to get flight times
# In a real application, this would fetch data from a live database or API
def get_flight_times(departure: str, arrival: str) -> str:
flights = {
'NYC-LAX': {'departure': '08:00 AM', 'arrival': '11:30 AM', 'duration': '5h 30m'},
'LAX-NYC': {'departure': '02:00 PM', 'arrival': '10:30 PM', 'duration': '5h 30m'},
'LHR-JFK': {'departure': '10:00 AM', 'arrival': '01:00 PM', 'duration': '8h 00m'},
'JFK-LHR': {'departure': '09:00 PM', 'arrival': '09:00 AM', 'duration': '7h 00m'},
'CDG-DXB': {'departure': '11:00 AM', 'arrival': '08:00 PM', 'duration': '6h 00m'},
'DXB-CDG': {'departure': '03:00 AM', 'arrival': '07:30 AM', 'duration': '7h 30m'},
}
key = f'{departure}-{arrival}'.upper()
return json.dumps(flights.get(key, {'error': 'Flight not found'}))
async def run(model: str):
client = ollama.AsyncClient()
# Initialize conversation with a user query
messages = [{'role': 'user', 'content': 'What is the flight time from New York (NYC) to Los Angeles (LAX)?'}]
# First API call: Send the query and function description to the model
response = await client.chat(
model=model,
messages=messages,
tools=[
{
'type': 'function',
'function': {
'name': 'get_flight_times',
'description': 'Get the flight times between two cities',
'parameters': {
'type': 'object',
'properties': {
'departure': {
'type': 'string',
'description': 'The departure city (airport code)',
},
'arrival': {
'type': 'string',
'description': 'The arrival city (airport code)',
},
},
'required': ['departure', 'arrival'],
},
},
},
],
)
# Add the model's response to the conversation history
messages.append(response['message'])
# Check if the model decided to use the provided function
if not response['message'].get('tool_calls'):
print("The model didn't use the function. Its response was:")
print(response['message']['content'])
return
# Process function calls made by the model
if response['message'].get('tool_calls'):
available_functions = {
'get_flight_times': get_flight_times,
}
for tool in response['message']['tool_calls']:
function_to_call = available_functions[tool['function']['name']]
function_response = function_to_call(tool['function']['arguments']['departure'], tool['function']['arguments']['arrival'])
# Add function response to the conversation
messages.append(
{
'role': 'tool',
'content': function_response,
}
)
# Second API call: Get final response from the model
final_response = await client.chat(model=model, messages=messages)
print(final_response['message']['content'])
# Run the async function
asyncio.run(run('mistral'))
-2
View File
@@ -21,7 +21,6 @@ __all__ = [
'ResponseError',
'generate',
'chat',
'embed',
'embeddings',
'pull',
'push',
@@ -37,7 +36,6 @@ _client = Client()
generate = _client.generate
chat = _client.chat
embed = _client.embed
embeddings = _client.embeddings
pull = _client.pull
push = _client.push
+17 -65
View File
@@ -27,7 +27,7 @@ try:
except metadata.PackageNotFoundError:
__version__ = '0.0.0'
from ollama._types import Message, Options, RequestError, ResponseError, Tool
from ollama._types import Message, Options, RequestError, ResponseError
class BaseClient:
@@ -102,7 +102,6 @@ class Client(BaseClient):
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -110,6 +109,7 @@ class Client(BaseClient):
raw: bool = False,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
tools: Optional[Sequence[Any]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]: ...
@@ -119,7 +119,6 @@ class Client(BaseClient):
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -127,6 +126,7 @@ class Client(BaseClient):
raw: bool = False,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
tools: Optional[Sequence[Any]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Iterator[Mapping[str, Any]]: ...
@@ -135,7 +135,6 @@ class Client(BaseClient):
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -143,6 +142,7 @@ class Client(BaseClient):
raw: bool = False,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
tools: Optional[Sequence[Any]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
@@ -165,13 +165,13 @@ class Client(BaseClient):
json={
'model': model,
'prompt': prompt,
'suffix': suffix,
'system': system,
'template': template,
'context': context or [],
'stream': stream,
'raw': raw,
'images': [_encode_image(image) for image in images or []],
'tools': tools or [],
'format': format,
'options': options or {},
'keep_alive': keep_alive,
@@ -184,7 +184,6 @@ class Client(BaseClient):
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[False] = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
@@ -196,7 +195,6 @@ class Client(BaseClient):
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[True] = True,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
@@ -207,7 +205,6 @@ class Client(BaseClient):
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: bool = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
@@ -229,6 +226,12 @@ class Client(BaseClient):
messages = deepcopy(messages)
for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of Message or dict-like objects')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if 'content' not in message:
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]
@@ -238,7 +241,6 @@ class Client(BaseClient):
json={
'model': model,
'messages': messages,
'tools': tools or [],
'stream': stream,
'format': format,
'options': options or {},
@@ -247,29 +249,6 @@ class Client(BaseClient):
stream=stream,
)
def embed(
self,
model: str = '',
input: Union[str, Sequence[AnyStr]] = '',
truncate: bool = True,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]:
if not model:
raise RequestError('must provide a model')
return self._request(
'POST',
'/api/embed',
json={
'model': model,
'input': input,
'truncate': truncate,
'options': options or {},
'keep_alive': keep_alive,
},
).json()
def embeddings(
self,
model: str = '',
@@ -522,7 +501,6 @@ class AsyncClient(BaseClient):
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -539,7 +517,6 @@ class AsyncClient(BaseClient):
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -555,7 +532,6 @@ class AsyncClient(BaseClient):
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -584,7 +560,6 @@ class AsyncClient(BaseClient):
json={
'model': model,
'prompt': prompt,
'suffix': suffix,
'system': system,
'template': template,
'context': context or [],
@@ -603,7 +578,6 @@ class AsyncClient(BaseClient):
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[False] = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
@@ -615,7 +589,6 @@ class AsyncClient(BaseClient):
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[True] = True,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
@@ -626,7 +599,6 @@ class AsyncClient(BaseClient):
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: bool = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
@@ -647,6 +619,12 @@ class AsyncClient(BaseClient):
messages = deepcopy(messages)
for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of strings')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if 'content' not in message:
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]
@@ -656,7 +634,6 @@ class AsyncClient(BaseClient):
json={
'model': model,
'messages': messages,
'tools': tools or [],
'stream': stream,
'format': format,
'options': options or {},
@@ -665,31 +642,6 @@ class AsyncClient(BaseClient):
stream=stream,
)
async def embed(
self,
model: str = '',
input: Union[str, Sequence[AnyStr]] = '',
truncate: bool = True,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]:
if not model:
raise RequestError('must provide a model')
response = await self._request(
'POST',
'/api/embed',
json={
'model': model,
'input': input,
'truncate': truncate,
'options': options or {},
'keep_alive': keep_alive,
},
)
return response.json()
async def embeddings(
self,
model: str = '',
+3 -50
View File
@@ -1,5 +1,5 @@
import json
from typing import Any, TypedDict, Sequence, Literal, Mapping
from typing import Any, TypedDict, Sequence, Literal
import sys
@@ -52,27 +52,8 @@ class GenerateResponse(BaseGenerateResponse):
context: Sequence[int]
'Tokenized history up to the point of the response.'
class ToolCallFunction(TypedDict):
"""
Tool call function.
"""
name: str
'Name of the function.'
args: NotRequired[Mapping[str, Any]]
'Arguments of the function.'
class ToolCall(TypedDict):
"""
Model tool calls.
"""
function: ToolCallFunction
'Function to be called.'
tool_calls: Sequence[Any]
'List of tool calls made by the model.'
class Message(TypedDict):
"""
@@ -97,34 +78,6 @@ class Message(TypedDict):
Valid image formats depend on the model. See the model card for more information.
"""
tool_calls: NotRequired[Sequence[ToolCall]]
"""
Tools calls to be made by the model.
"""
class Property(TypedDict):
type: str
description: str
enum: NotRequired[Sequence[str]] # `enum` is optional and can be a list of strings
class Parameters(TypedDict):
type: str
required: Sequence[str]
properties: Mapping[str, Property]
class ToolFunction(TypedDict):
name: str
description: str
parameters: Parameters
class Tool(TypedDict):
type: str
function: ToolFunction
class ChatResponse(BaseGenerateResponse):
"""
-12
View File
@@ -26,7 +26,6 @@ def test_client_chat(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
'format': '',
'options': {},
@@ -74,7 +73,6 @@ def test_client_chat_stream(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': True,
'format': '',
'options': {},
@@ -104,7 +102,6 @@ def test_client_chat_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'tools': [],
'stream': False,
'format': '',
'options': {},
@@ -137,7 +134,6 @@ def test_client_generate(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -183,7 +179,6 @@ def test_client_generate_stream(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -212,7 +207,6 @@ def test_client_generate_images(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -528,7 +522,6 @@ async def test_async_client_chat(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
'format': '',
'options': {},
@@ -567,7 +560,6 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': True,
'format': '',
'options': {},
@@ -598,7 +590,6 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'tools': [],
'stream': False,
'format': '',
'options': {},
@@ -622,7 +613,6 @@ async def test_async_client_generate(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -663,7 +653,6 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -693,7 +682,6 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],