Compare commits

..

6 Commits

Author SHA1 Message Date
royjhan 33c4b61ff9 add insert support to generate endpoint (#215)
* add suffix

* update fill-in-the-middle example

* keep example

* lint

* variables
2024-07-18 11:04:17 -07:00
royjhan b0ea6d9e44 Support api/embed (#208)
* api/embed

* api/embed

* api/embed

* rm legacy
2024-07-18 10:40:30 -07:00
Josh 359c63daa7 integrate tool calls (#213) 2024-07-17 09:40:49 -07:00
Jeffrey Morgan 1a15742705 Update README.md 2024-06-21 22:00:54 -04:00
royjhan ce56f279e8 Add type overloads to methods (#181)
* Add type overloads for chat() method in _client.py

* Overloading

* Fix Overload Overlap

* Fix async chat

* Lint

* Reverse

---------

Co-authored-by: Simon Ottenhaus <simon.ottenhaus@kenbun.de>
2024-06-19 16:10:44 -07:00
royjhan 982d65fea0 Simple Example (#179) 2024-06-18 13:23:07 -07:00
9 changed files with 466 additions and 36 deletions
-18
View File
@@ -2,24 +2,6 @@
The Ollama Python library provides the easiest way to integrate Python 3.8+ projects with [Ollama](https://github.com/ollama/ollama).
## Prerequisites
You need to have a local ollama server running to be able to continue. To do this:
- Download: https://ollama.com/
- Run an LLM: https://ollama.com/library
- Example: `ollama run llama2`
- Example: `ollama run llama2:70b`
Then:
```sh
curl https://ollama.ai/install.sh | sh
ollama serve
```
Next you can go ahead with `ollama-python`.
## Install
```sh
+3 -3
View File
@@ -1,16 +1,16 @@
from ollama import generate
prefix = '''def remove_non_ascii(s: str) -> str:
prompt = '''def remove_non_ascii(s: str) -> str:
""" '''
suffix = """
return result
"""
response = generate(
model='codellama:7b-code',
prompt=f'<PRE> {prefix} <SUF>{suffix} <MID>',
prompt=prompt,
suffix=suffix,
options={
'num_predict': 128,
'temperature': 0,
+31
View File
@@ -0,0 +1,31 @@
from ollama import ps, pull, chat
# Pull the model with streaming enabled and print each distinct status once,
# so repeated progress updates for the same phase do not flood the output.
response = pull('mistral', stream=True)
progress_states = set()
for progress in response:
    status = progress.get('status')
    if status in progress_states:
        continue
    progress_states.add(status)
    print(status)

print('\n')

# Send one chat message and print the reply
# (presumably this also leaves the model loaded for the `ps` call below).
response = chat('mistral', messages=[{'role': 'user', 'content': 'Hello!'}])
print(response['message']['content'])

print('\n')

# Report the CPU/GPU split for every model returned by `ps`, instead of
# indexing element 0 unconditionally (which raises IndexError on an empty list).
response = ps()
models = response.get('models') or []
if not models:
    print('No models are currently loaded')

for model in models:
    name = model['name']
    size = model['size']
    size_vram = model['size_vram']

    if size == size_vram:
        print(f'{name}: 100% GPU')
    elif not size_vram:
        print(f'{name}: 100% CPU')
    else:
        # Whatever is not reported under `size_vram` is counted as CPU.
        size_cpu = size - size_vram
        cpu_percent = round(size_cpu / size * 100)
        print(f'{name}: {cpu_percent}% CPU/{100 - cpu_percent}% GPU')
+3
View File
@@ -0,0 +1,3 @@
# tools
This example demonstrates how to utilize tool calls with an asynchronous Ollama client and the chat endpoint.
+87
View File
@@ -0,0 +1,87 @@
import json
import ollama
import asyncio
def get_flight_times(departure: str, arrival: str) -> str:
    """Return the flight details for a route as a JSON string.

    Simulates an API call: the data comes from a fixed in-memory table, where
    a real application would query a live database or API. The lookup key is
    ``'<departure>-<arrival>'`` upper-cased, so the codes are matched
    case-insensitively. An unknown route yields
    ``{"error": "Flight not found"}``.
    """
    schedule = {
        'NYC-LAX': {'departure': '08:00 AM', 'arrival': '11:30 AM', 'duration': '5h 30m'},
        'LAX-NYC': {'departure': '02:00 PM', 'arrival': '10:30 PM', 'duration': '5h 30m'},
        'LHR-JFK': {'departure': '10:00 AM', 'arrival': '01:00 PM', 'duration': '8h 00m'},
        'JFK-LHR': {'departure': '09:00 PM', 'arrival': '09:00 AM', 'duration': '7h 00m'},
        'CDG-DXB': {'departure': '11:00 AM', 'arrival': '08:00 PM', 'duration': '6h 00m'},
        'DXB-CDG': {'departure': '03:00 AM', 'arrival': '07:30 AM', 'duration': '7h 30m'},
    }

    route = '{}-{}'.format(departure, arrival).upper()
    try:
        details = schedule[route]
    except KeyError:
        details = {'error': 'Flight not found'}
    return json.dumps(details)
async def run(model: str):
    """Ask `model` a flight-time question and answer it through a tool call.

    Sends the user question together with the `get_flight_times` tool
    description. If the reply carries no tool calls, the reply text is printed
    and the function returns. Otherwise every requested tool call is executed,
    its result is appended to the conversation as a `tool` message, and a
    second chat request produces the final answer, which is printed.

    A tool call naming an unknown function, or lacking a required argument, is
    reported back to the model as a JSON error payload instead of raising
    KeyError.
    """
    client = ollama.AsyncClient()
    # Initialize conversation with a user query
    messages = [{'role': 'user', 'content': 'What is the flight time from New York (NYC) to Los Angeles (LAX)?'}]

    # First API call: Send the query and function description to the model
    response = await client.chat(
        model=model,
        messages=messages,
        tools=[
            {
                'type': 'function',
                'function': {
                    'name': 'get_flight_times',
                    'description': 'Get the flight times between two cities',
                    'parameters': {
                        'type': 'object',
                        'properties': {
                            'departure': {
                                'type': 'string',
                                'description': 'The departure city (airport code)',
                            },
                            'arrival': {
                                'type': 'string',
                                'description': 'The arrival city (airport code)',
                            },
                        },
                        'required': ['departure', 'arrival'],
                    },
                },
            },
        ],
    )

    # Add the model's response to the conversation history
    messages.append(response['message'])

    # Check if the model decided to use the provided function
    tool_calls = response['message'].get('tool_calls')
    if not tool_calls:
        print("The model didn't use the function. Its response was:")
        print(response['message']['content'])
        return

    # Process function calls made by the model
    available_functions = {
        'get_flight_times': get_flight_times,
    }
    for tool in tool_calls:
        function_name = tool['function']['name']
        function_to_call = available_functions.get(function_name)
        arguments = tool['function'].get('arguments') or {}

        if function_to_call is None:
            # The name comes from the model's reply, so it may not be one we offer.
            function_response = json.dumps({'error': f'Unknown function: {function_name}'})
        else:
            try:
                departure = arguments['departure']
                arrival = arguments['arrival']
            except KeyError as missing:
                function_response = json.dumps({'error': f'Missing argument: {missing.args[0]}'})
            else:
                function_response = function_to_call(departure, arrival)

        # Add function response to the conversation
        messages.append(
            {
                'role': 'tool',
                'content': function_response,
            }
        )

    # Second API call: Get final response from the model
    final_response = await client.chat(model=model, messages=messages)
    print(final_response['message']['content'])
# Run the async function
asyncio.run(run('mistral'))
+2
View File
@@ -21,6 +21,7 @@ __all__ = [
'ResponseError',
'generate',
'chat',
'embed',
'embeddings',
'pull',
'push',
@@ -36,6 +37,7 @@ _client = Client()
generate = _client.generate
chat = _client.chat
embed = _client.embed
embeddings = _client.embeddings
pull = _client.pull
push = _client.push
+278 -14
View File
@@ -11,7 +11,7 @@ from copy import deepcopy
from hashlib import sha256
from base64 import b64encode, b64decode
from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal
from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal, overload
import sys
@@ -27,7 +27,7 @@ try:
except metadata.PackageNotFoundError:
__version__ = '0.0.0'
from ollama._types import Message, Options, RequestError, ResponseError
from ollama._types import Message, Options, RequestError, ResponseError, Tool
class BaseClient:
@@ -97,10 +97,45 @@ class Client(BaseClient):
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
return self._stream(*args, **kwargs) if stream else self._request(*args, **kwargs).json()
# Typing-only overloads for `generate`: the declared return type is selected
# by the literal type of `stream`. Both bodies are `...`; the implementation
# signature follows these stubs.
# Overload 1: `stream` is Literal[False] -> a single Mapping is declared.
@overload
def generate(
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
raw: bool = False,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]: ...
# Overload 2: `stream` is Literal[True] -> an Iterator of Mappings is declared.
@overload
def generate(
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
raw: bool = False,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Iterator[Mapping[str, Any]]: ...
def generate(
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -130,6 +165,7 @@ class Client(BaseClient):
json={
'model': model,
'prompt': prompt,
'suffix': suffix,
'system': system,
'template': template,
'context': context or [],
@@ -143,10 +179,35 @@ class Client(BaseClient):
stream=stream,
)
# Typing-only overloads for `chat`: the declared return type is selected by
# the literal type of `stream`. Both bodies are `...`.
# Overload 1: `stream` is Literal[False] -> a single Mapping is declared.
@overload
def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[False] = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]: ...
# Overload 2: `stream` is Literal[True] -> an Iterator of Mappings is declared.
# NOTE(review): this stub spells the default as `True`, while the
# implementation signature that follows declares `stream: bool = False`.
@overload
def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[True] = True,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Iterator[Mapping[str, Any]]: ...
def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: bool = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
@@ -168,12 +229,6 @@ class Client(BaseClient):
messages = deepcopy(messages)
for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of Message or dict-like objects')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if 'content' not in message:
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]
@@ -183,6 +238,7 @@ class Client(BaseClient):
json={
'model': model,
'messages': messages,
'tools': tools or [],
'stream': stream,
'format': format,
'options': options or {},
@@ -191,6 +247,29 @@ class Client(BaseClient):
stream=stream,
)
def embed(
    self,
    model: str = '',
    input: Union[str, Sequence[AnyStr]] = '',
    truncate: bool = True,
    options: Optional[Options] = None,
    keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]:
    """
    POST an embed request to `/api/embed` and return the decoded JSON body.

    Raises `RequestError` if no model is provided. A falsy `options` is sent
    as an empty dict; `keep_alive` is forwarded unchanged.
    """
    if not model:
        raise RequestError('must provide a model')

    payload = {
        'model': model,
        'input': input,
        'truncate': truncate,
        'options': options or {},
        'keep_alive': keep_alive,
    }
    reply = self._request('POST', '/api/embed', json=payload)
    return reply.json()
def embeddings(
self,
model: str = '',
@@ -209,6 +288,22 @@ class Client(BaseClient):
},
).json()
# Typing-only overloads for `pull`: the declared return type is selected by
# the literal type of `stream`. Both bodies are `...`.
# Overload 1: `stream` is Literal[False] -> a single Mapping is declared.
@overload
def pull(
self,
model: str,
insecure: bool = False,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
# Overload 2: `stream` is Literal[True] -> an Iterator of Mappings is declared.
@overload
def pull(
self,
model: str,
insecure: bool = False,
stream: Literal[True] = True,
) -> Iterator[Mapping[str, Any]]: ...
def pull(
self,
model: str,
@@ -231,6 +326,22 @@ class Client(BaseClient):
stream=stream,
)
# Typing-only overloads for `push`: the declared return type is selected by
# the literal type of `stream`. Both bodies are `...`.
# Overload 1: `stream` is Literal[False] -> a single Mapping is declared.
@overload
def push(
self,
model: str,
insecure: bool = False,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
# Overload 2: `stream` is Literal[True] -> an Iterator of Mappings is declared.
@overload
def push(
self,
model: str,
insecure: bool = False,
stream: Literal[True] = True,
) -> Iterator[Mapping[str, Any]]: ...
def push(
self,
model: str,
@@ -253,6 +364,26 @@ class Client(BaseClient):
stream=stream,
)
# Typing-only overloads for `create`: the declared return type is selected by
# the literal type of `stream`. Both bodies are `...`.
# Overload 1: `stream` is Literal[False] -> a single Mapping is declared.
@overload
def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
quantize: Optional[str] = None,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
# Overload 2: `stream` is Literal[True] -> an Iterator of Mappings is declared.
@overload
def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
quantize: Optional[str] = None,
stream: Literal[True] = True,
) -> Iterator[Mapping[str, Any]]: ...
def create(
self,
model: str,
@@ -386,10 +517,45 @@ class AsyncClient(BaseClient):
response = await self._request(*args, **kwargs)
return response.json()
# Typing-only overloads for the async `generate`: the declared return type is
# selected by the literal type of `stream`. Both bodies are `...`.
# Overload 1: `stream` is Literal[False] -> a single Mapping is declared.
@overload
async def generate(
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
raw: bool = False,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]: ...
# Overload 2: `stream` is Literal[True] -> an AsyncIterator of Mappings is declared.
@overload
async def generate(
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
raw: bool = False,
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> AsyncIterator[Mapping[str, Any]]: ...
async def generate(
self,
model: str = '',
prompt: str = '',
suffix: str = '',
system: str = '',
template: str = '',
context: Optional[Sequence[int]] = None,
@@ -418,6 +584,7 @@ class AsyncClient(BaseClient):
json={
'model': model,
'prompt': prompt,
'suffix': suffix,
'system': system,
'template': template,
'context': context or [],
@@ -431,10 +598,35 @@ class AsyncClient(BaseClient):
stream=stream,
)
# Typing-only overloads for the async `chat`: the declared return type is
# selected by the literal type of `stream`. Both bodies are `...`.
# Overload 1: `stream` is Literal[False] -> a single Mapping is declared.
@overload
async def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[False] = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]: ...
# Overload 2: `stream` is Literal[True] -> an AsyncIterator of Mappings is declared.
# NOTE(review): this stub spells the default as `True`, while the
# implementation signature that follows declares `stream: bool = False`.
@overload
async def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[True] = True,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> AsyncIterator[Mapping[str, Any]]: ...
async def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: bool = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
@@ -455,12 +647,6 @@ class AsyncClient(BaseClient):
messages = deepcopy(messages)
for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of strings')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if 'content' not in message:
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]
@@ -470,6 +656,7 @@ class AsyncClient(BaseClient):
json={
'model': model,
'messages': messages,
'tools': tools or [],
'stream': stream,
'format': format,
'options': options or {},
@@ -478,6 +665,31 @@ class AsyncClient(BaseClient):
stream=stream,
)
async def embed(
    self,
    model: str = '',
    input: Union[str, Sequence[AnyStr]] = '',
    truncate: bool = True,
    options: Optional[Options] = None,
    keep_alive: Optional[Union[float, str]] = None,
) -> Mapping[str, Any]:
    """
    Await a POST to `/api/embed` and return the decoded JSON body.

    Raises `RequestError` if no model is provided. A falsy `options` is sent
    as an empty dict; `keep_alive` is forwarded unchanged.
    """
    if not model:
        raise RequestError('must provide a model')

    payload = {
        'model': model,
        'input': input,
        'truncate': truncate,
        'options': options or {},
        'keep_alive': keep_alive,
    }
    reply = await self._request('POST', '/api/embed', json=payload)
    return reply.json()
async def embeddings(
self,
model: str = '',
@@ -498,6 +710,22 @@ class AsyncClient(BaseClient):
return response.json()
# Typing-only overloads for the async `pull`: the declared return type is
# selected by the literal type of `stream`. Both bodies are `...`.
# Overload 1: `stream` is Literal[False] -> a single Mapping is declared.
@overload
async def pull(
self,
model: str,
insecure: bool = False,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
# Overload 2: `stream` is Literal[True] -> an AsyncIterator of Mappings is declared.
@overload
async def pull(
self,
model: str,
insecure: bool = False,
stream: Literal[True] = True,
) -> AsyncIterator[Mapping[str, Any]]: ...
async def pull(
self,
model: str,
@@ -520,6 +748,22 @@ class AsyncClient(BaseClient):
stream=stream,
)
# Typing-only overloads for the async `push`: the declared return type is
# selected by the literal type of `stream`. Both bodies are `...`.
# Overload 1: `stream` is Literal[False] -> a single Mapping is declared.
@overload
async def push(
self,
model: str,
insecure: bool = False,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
# Overload 2: `stream` is Literal[True] -> an AsyncIterator of Mappings is declared.
@overload
async def push(
self,
model: str,
insecure: bool = False,
stream: Literal[True] = True,
) -> AsyncIterator[Mapping[str, Any]]: ...
async def push(
self,
model: str,
@@ -542,6 +786,26 @@ class AsyncClient(BaseClient):
stream=stream,
)
# Typing-only overloads for the async `create`: the declared return type is
# selected by the literal type of `stream`. Both bodies are `...`.
# Overload 1: `stream` is Literal[False] -> a single Mapping is declared.
@overload
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
quantize: Optional[str] = None,
stream: Literal[False] = False,
) -> Mapping[str, Any]: ...
# Overload 2: `stream` is Literal[True] -> an AsyncIterator of Mappings is declared.
@overload
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
quantize: Optional[str] = None,
stream: Literal[True] = True,
) -> AsyncIterator[Mapping[str, Any]]: ...
async def create(
self,
model: str,
+50 -1
View File
@@ -1,5 +1,5 @@
import json
from typing import Any, TypedDict, Sequence, Literal
from typing import Any, TypedDict, Sequence, Literal, Mapping
import sys
@@ -53,6 +53,27 @@ class GenerateResponse(BaseGenerateResponse):
'Tokenized history up to the point of the response.'
class ToolCallFunction(TypedDict):
"""
Tool call function.
"""
name: str
'Name of the function.'
args: NotRequired[Mapping[str, Any]]
'Arguments of the function.'
# Typed shape of one entry in a message's `tool_calls` sequence; it wraps the
# function invocation described by `ToolCallFunction`.
class ToolCall(TypedDict):
"""
Model tool calls.
"""
function: ToolCallFunction
'Function to be called.'
class Message(TypedDict):
"""
Chat message.
@@ -76,6 +97,34 @@ class Message(TypedDict):
Valid image formats depend on the model. See the model card for more information.
"""
tool_calls: NotRequired[Sequence[ToolCall]]
"""
Tool calls to be made by the model.
"""
# Typed description of one named entry in `Parameters.properties`.
class Property(TypedDict):
# presumably a JSON-Schema style type name such as 'string' — confirm
type: str
# Human-readable explanation of the parameter.
description: str
enum: NotRequired[Sequence[str]] # `enum` is optional and can be a list of strings
# Typed description of the parameter set accepted by a tool function.
class Parameters(TypedDict):
# presumably the schema type of the parameter container (e.g. 'object') — confirm
type: str
# Names of parameters that must be supplied; declared as a required key here.
required: Sequence[str]
# Maps each parameter name to its `Property` description.
properties: Mapping[str, Property]
# Typed description of a callable offered as a tool: its name, a description,
# and the `Parameters` it accepts.
class ToolFunction(TypedDict):
name: str
description: str
parameters: Parameters
# Typed shape of one tool definition: a `type` tag plus the `ToolFunction` it
# describes. presumably `type` is 'function' for function tools — confirm
class Tool(TypedDict):
type: str
function: ToolFunction
class ChatResponse(BaseGenerateResponse):
"""
+12
View File
@@ -26,6 +26,7 @@ def test_client_chat(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
'format': '',
'options': {},
@@ -73,6 +74,7 @@ def test_client_chat_stream(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': True,
'format': '',
'options': {},
@@ -102,6 +104,7 @@ def test_client_chat_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'tools': [],
'stream': False,
'format': '',
'options': {},
@@ -134,6 +137,7 @@ def test_client_generate(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -179,6 +183,7 @@ def test_client_generate_stream(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -207,6 +212,7 @@ def test_client_generate_images(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -522,6 +528,7 @@ async def test_async_client_chat(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
'format': '',
'options': {},
@@ -560,6 +567,7 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': True,
'format': '',
'options': {},
@@ -590,6 +598,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'tools': [],
'stream': False,
'format': '',
'options': {},
@@ -613,6 +622,7 @@ async def test_async_client_generate(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -653,6 +663,7 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],
@@ -682,6 +693,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'suffix': '',
'system': '',
'template': '',
'context': [],