mirror of
https://github.com/ollama/ollama-python.git
synced 2026-06-16 21:24:52 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 33c4b61ff9 | |||
| b0ea6d9e44 | |||
| 359c63daa7 | |||
| 1a15742705 | |||
| ce56f279e8 | |||
| 982d65fea0 |
@@ -2,24 +2,6 @@
|
||||
|
||||
The Ollama Python library provides the easiest way to integrate Python 3.8+ projects with [Ollama](https://github.com/ollama/ollama).
|
||||
|
||||
## Prerequisites
|
||||
|
||||
You need to have a local ollama server running to be able to continue. To do this:
|
||||
|
||||
- Download: https://ollama.com/
|
||||
- Run an LLM: https://ollama.com/library
|
||||
- Example: `ollama run llama2`
|
||||
- Example: `ollama run llama2:70b`
|
||||
|
||||
Then:
|
||||
|
||||
```sh
|
||||
curl https://ollama.ai/install.sh | sh
|
||||
ollama serve
|
||||
```
|
||||
|
||||
Next you can go ahead with `ollama-python`.
|
||||
|
||||
## Install
|
||||
|
||||
```sh
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from ollama import generate
|
||||
|
||||
prefix = '''def remove_non_ascii(s: str) -> str:
|
||||
prompt = '''def remove_non_ascii(s: str) -> str:
|
||||
""" '''
|
||||
|
||||
suffix = """
|
||||
return result
|
||||
"""
|
||||
|
||||
|
||||
response = generate(
|
||||
model='codellama:7b-code',
|
||||
prompt=f'<PRE> {prefix} <SUF>{suffix} <MID>',
|
||||
prompt=prompt,
|
||||
suffix=suffix,
|
||||
options={
|
||||
'num_predict': 128,
|
||||
'temperature': 0,
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
from ollama import ps, pull, chat
|
||||
|
||||
response = pull('mistral', stream=True)
|
||||
progress_states = set()
|
||||
for progress in response:
|
||||
if progress.get('status') in progress_states:
|
||||
continue
|
||||
progress_states.add(progress.get('status'))
|
||||
print(progress.get('status'))
|
||||
|
||||
print('\n')
|
||||
|
||||
response = chat('mistral', messages=[{'role': 'user', 'content': 'Hello!'}])
|
||||
print(response['message']['content'])
|
||||
|
||||
print('\n')
|
||||
|
||||
response = ps()
|
||||
|
||||
name = response['models'][0]['name']
|
||||
size = response['models'][0]['size']
|
||||
size_vram = response['models'][0]['size_vram']
|
||||
|
||||
if size == size_vram:
|
||||
print(f'{name}: 100% GPU')
|
||||
elif not size_vram:
|
||||
print(f'{name}: 100% CPU')
|
||||
else:
|
||||
size_cpu = size - size_vram
|
||||
cpu_percent = round(size_cpu / size * 100)
|
||||
print(f'{name}: {cpu_percent}% CPU/{100 - cpu_percent}% GPU')
|
||||
@@ -0,0 +1,3 @@
|
||||
# tools
|
||||
|
||||
This example demonstrates how to utilize tool calls with an asynchronous Ollama client and the chat endpoint.
|
||||
@@ -0,0 +1,87 @@
|
||||
import json
|
||||
import ollama
|
||||
import asyncio
|
||||
|
||||
|
||||
# Simulates an API call to get flight times
|
||||
# In a real application, this would fetch data from a live database or API
|
||||
def get_flight_times(departure: str, arrival: str) -> str:
|
||||
flights = {
|
||||
'NYC-LAX': {'departure': '08:00 AM', 'arrival': '11:30 AM', 'duration': '5h 30m'},
|
||||
'LAX-NYC': {'departure': '02:00 PM', 'arrival': '10:30 PM', 'duration': '5h 30m'},
|
||||
'LHR-JFK': {'departure': '10:00 AM', 'arrival': '01:00 PM', 'duration': '8h 00m'},
|
||||
'JFK-LHR': {'departure': '09:00 PM', 'arrival': '09:00 AM', 'duration': '7h 00m'},
|
||||
'CDG-DXB': {'departure': '11:00 AM', 'arrival': '08:00 PM', 'duration': '6h 00m'},
|
||||
'DXB-CDG': {'departure': '03:00 AM', 'arrival': '07:30 AM', 'duration': '7h 30m'},
|
||||
}
|
||||
|
||||
key = f'{departure}-{arrival}'.upper()
|
||||
return json.dumps(flights.get(key, {'error': 'Flight not found'}))
|
||||
|
||||
|
||||
async def run(model: str):
|
||||
client = ollama.AsyncClient()
|
||||
# Initialize conversation with a user query
|
||||
messages = [{'role': 'user', 'content': 'What is the flight time from New York (NYC) to Los Angeles (LAX)?'}]
|
||||
|
||||
# First API call: Send the query and function description to the model
|
||||
response = await client.chat(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=[
|
||||
{
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'get_flight_times',
|
||||
'description': 'Get the flight times between two cities',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'departure': {
|
||||
'type': 'string',
|
||||
'description': 'The departure city (airport code)',
|
||||
},
|
||||
'arrival': {
|
||||
'type': 'string',
|
||||
'description': 'The arrival city (airport code)',
|
||||
},
|
||||
},
|
||||
'required': ['departure', 'arrival'],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Add the model's response to the conversation history
|
||||
messages.append(response['message'])
|
||||
|
||||
# Check if the model decided to use the provided function
|
||||
if not response['message'].get('tool_calls'):
|
||||
print("The model didn't use the function. Its response was:")
|
||||
print(response['message']['content'])
|
||||
return
|
||||
|
||||
# Process function calls made by the model
|
||||
if response['message'].get('tool_calls'):
|
||||
available_functions = {
|
||||
'get_flight_times': get_flight_times,
|
||||
}
|
||||
for tool in response['message']['tool_calls']:
|
||||
function_to_call = available_functions[tool['function']['name']]
|
||||
function_response = function_to_call(tool['function']['arguments']['departure'], tool['function']['arguments']['arrival'])
|
||||
# Add function response to the conversation
|
||||
messages.append(
|
||||
{
|
||||
'role': 'tool',
|
||||
'content': function_response,
|
||||
}
|
||||
)
|
||||
|
||||
# Second API call: Get final response from the model
|
||||
final_response = await client.chat(model=model, messages=messages)
|
||||
print(final_response['message']['content'])
|
||||
|
||||
|
||||
# Run the async function
|
||||
asyncio.run(run('mistral'))
|
||||
@@ -21,6 +21,7 @@ __all__ = [
|
||||
'ResponseError',
|
||||
'generate',
|
||||
'chat',
|
||||
'embed',
|
||||
'embeddings',
|
||||
'pull',
|
||||
'push',
|
||||
@@ -36,6 +37,7 @@ _client = Client()
|
||||
|
||||
generate = _client.generate
|
||||
chat = _client.chat
|
||||
embed = _client.embed
|
||||
embeddings = _client.embeddings
|
||||
pull = _client.pull
|
||||
push = _client.push
|
||||
|
||||
+278
-14
@@ -11,7 +11,7 @@ from copy import deepcopy
|
||||
from hashlib import sha256
|
||||
from base64 import b64encode, b64decode
|
||||
|
||||
from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal
|
||||
from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal, overload
|
||||
|
||||
import sys
|
||||
|
||||
@@ -27,7 +27,7 @@ try:
|
||||
except metadata.PackageNotFoundError:
|
||||
__version__ = '0.0.0'
|
||||
|
||||
from ollama._types import Message, Options, RequestError, ResponseError
|
||||
from ollama._types import Message, Options, RequestError, ResponseError, Tool
|
||||
|
||||
|
||||
class BaseClient:
|
||||
@@ -97,10 +97,45 @@ class Client(BaseClient):
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
return self._stream(*args, **kwargs) if stream else self._request(*args, **kwargs).json()
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
model: str = '',
|
||||
prompt: str = '',
|
||||
suffix: str = '',
|
||||
system: str = '',
|
||||
template: str = '',
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: Literal[False] = False,
|
||||
raw: bool = False,
|
||||
format: Literal['', 'json'] = '',
|
||||
images: Optional[Sequence[AnyStr]] = None,
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
model: str = '',
|
||||
prompt: str = '',
|
||||
suffix: str = '',
|
||||
system: str = '',
|
||||
template: str = '',
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: Literal[True] = True,
|
||||
raw: bool = False,
|
||||
format: Literal['', 'json'] = '',
|
||||
images: Optional[Sequence[AnyStr]] = None,
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Iterator[Mapping[str, Any]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
model: str = '',
|
||||
prompt: str = '',
|
||||
suffix: str = '',
|
||||
system: str = '',
|
||||
template: str = '',
|
||||
context: Optional[Sequence[int]] = None,
|
||||
@@ -130,6 +165,7 @@ class Client(BaseClient):
|
||||
json={
|
||||
'model': model,
|
||||
'prompt': prompt,
|
||||
'suffix': suffix,
|
||||
'system': system,
|
||||
'template': template,
|
||||
'context': context or [],
|
||||
@@ -143,10 +179,35 @@ class Client(BaseClient):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@overload
|
||||
def chat(
|
||||
self,
|
||||
model: str = '',
|
||||
messages: Optional[Sequence[Message]] = None,
|
||||
tools: Optional[Sequence[Tool]] = None,
|
||||
stream: Literal[False] = False,
|
||||
format: Literal['', 'json'] = '',
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def chat(
|
||||
self,
|
||||
model: str = '',
|
||||
messages: Optional[Sequence[Message]] = None,
|
||||
tools: Optional[Sequence[Tool]] = None,
|
||||
stream: Literal[True] = True,
|
||||
format: Literal['', 'json'] = '',
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Iterator[Mapping[str, Any]]: ...
|
||||
|
||||
def chat(
|
||||
self,
|
||||
model: str = '',
|
||||
messages: Optional[Sequence[Message]] = None,
|
||||
tools: Optional[Sequence[Tool]] = None,
|
||||
stream: bool = False,
|
||||
format: Literal['', 'json'] = '',
|
||||
options: Optional[Options] = None,
|
||||
@@ -168,12 +229,6 @@ class Client(BaseClient):
|
||||
messages = deepcopy(messages)
|
||||
|
||||
for message in messages or []:
|
||||
if not isinstance(message, dict):
|
||||
raise TypeError('messages must be a list of Message or dict-like objects')
|
||||
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
|
||||
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
|
||||
if 'content' not in message:
|
||||
raise RequestError('messages must contain content')
|
||||
if images := message.get('images'):
|
||||
message['images'] = [_encode_image(image) for image in images]
|
||||
|
||||
@@ -183,6 +238,7 @@ class Client(BaseClient):
|
||||
json={
|
||||
'model': model,
|
||||
'messages': messages,
|
||||
'tools': tools or [],
|
||||
'stream': stream,
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
@@ -191,6 +247,29 @@ class Client(BaseClient):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def embed(
|
||||
self,
|
||||
model: str = '',
|
||||
input: Union[str, Sequence[AnyStr]] = '',
|
||||
truncate: bool = True,
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
if not model:
|
||||
raise RequestError('must provide a model')
|
||||
|
||||
return self._request(
|
||||
'POST',
|
||||
'/api/embed',
|
||||
json={
|
||||
'model': model,
|
||||
'input': input,
|
||||
'truncate': truncate,
|
||||
'options': options or {},
|
||||
'keep_alive': keep_alive,
|
||||
},
|
||||
).json()
|
||||
|
||||
def embeddings(
|
||||
self,
|
||||
model: str = '',
|
||||
@@ -209,6 +288,22 @@ class Client(BaseClient):
|
||||
},
|
||||
).json()
|
||||
|
||||
@overload
|
||||
def pull(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: Literal[False] = False,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def pull(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: Literal[True] = True,
|
||||
) -> Iterator[Mapping[str, Any]]: ...
|
||||
|
||||
def pull(
|
||||
self,
|
||||
model: str,
|
||||
@@ -231,6 +326,22 @@ class Client(BaseClient):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@overload
|
||||
def push(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: Literal[False] = False,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def push(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: Literal[True] = True,
|
||||
) -> Iterator[Mapping[str, Any]]: ...
|
||||
|
||||
def push(
|
||||
self,
|
||||
model: str,
|
||||
@@ -253,6 +364,26 @@ class Client(BaseClient):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@overload
|
||||
def create(
|
||||
self,
|
||||
model: str,
|
||||
path: Optional[Union[str, PathLike]] = None,
|
||||
modelfile: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
stream: Literal[False] = False,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def create(
|
||||
self,
|
||||
model: str,
|
||||
path: Optional[Union[str, PathLike]] = None,
|
||||
modelfile: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
stream: Literal[True] = True,
|
||||
) -> Iterator[Mapping[str, Any]]: ...
|
||||
|
||||
def create(
|
||||
self,
|
||||
model: str,
|
||||
@@ -386,10 +517,45 @@ class AsyncClient(BaseClient):
|
||||
response = await self._request(*args, **kwargs)
|
||||
return response.json()
|
||||
|
||||
@overload
|
||||
async def generate(
|
||||
self,
|
||||
model: str = '',
|
||||
prompt: str = '',
|
||||
suffix: str = '',
|
||||
system: str = '',
|
||||
template: str = '',
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: Literal[False] = False,
|
||||
raw: bool = False,
|
||||
format: Literal['', 'json'] = '',
|
||||
images: Optional[Sequence[AnyStr]] = None,
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
async def generate(
|
||||
self,
|
||||
model: str = '',
|
||||
prompt: str = '',
|
||||
suffix: str = '',
|
||||
system: str = '',
|
||||
template: str = '',
|
||||
context: Optional[Sequence[int]] = None,
|
||||
stream: Literal[True] = True,
|
||||
raw: bool = False,
|
||||
format: Literal['', 'json'] = '',
|
||||
images: Optional[Sequence[AnyStr]] = None,
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> AsyncIterator[Mapping[str, Any]]: ...
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
model: str = '',
|
||||
prompt: str = '',
|
||||
suffix: str = '',
|
||||
system: str = '',
|
||||
template: str = '',
|
||||
context: Optional[Sequence[int]] = None,
|
||||
@@ -418,6 +584,7 @@ class AsyncClient(BaseClient):
|
||||
json={
|
||||
'model': model,
|
||||
'prompt': prompt,
|
||||
'suffix': suffix,
|
||||
'system': system,
|
||||
'template': template,
|
||||
'context': context or [],
|
||||
@@ -431,10 +598,35 @@ class AsyncClient(BaseClient):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@overload
|
||||
async def chat(
|
||||
self,
|
||||
model: str = '',
|
||||
messages: Optional[Sequence[Message]] = None,
|
||||
tools: Optional[Sequence[Tool]] = None,
|
||||
stream: Literal[False] = False,
|
||||
format: Literal['', 'json'] = '',
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
async def chat(
|
||||
self,
|
||||
model: str = '',
|
||||
messages: Optional[Sequence[Message]] = None,
|
||||
tools: Optional[Sequence[Tool]] = None,
|
||||
stream: Literal[True] = True,
|
||||
format: Literal['', 'json'] = '',
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> AsyncIterator[Mapping[str, Any]]: ...
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
model: str = '',
|
||||
messages: Optional[Sequence[Message]] = None,
|
||||
tools: Optional[Sequence[Tool]] = None,
|
||||
stream: bool = False,
|
||||
format: Literal['', 'json'] = '',
|
||||
options: Optional[Options] = None,
|
||||
@@ -455,12 +647,6 @@ class AsyncClient(BaseClient):
|
||||
messages = deepcopy(messages)
|
||||
|
||||
for message in messages or []:
|
||||
if not isinstance(message, dict):
|
||||
raise TypeError('messages must be a list of strings')
|
||||
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
|
||||
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
|
||||
if 'content' not in message:
|
||||
raise RequestError('messages must contain content')
|
||||
if images := message.get('images'):
|
||||
message['images'] = [_encode_image(image) for image in images]
|
||||
|
||||
@@ -470,6 +656,7 @@ class AsyncClient(BaseClient):
|
||||
json={
|
||||
'model': model,
|
||||
'messages': messages,
|
||||
'tools': tools or [],
|
||||
'stream': stream,
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
@@ -478,6 +665,31 @@ class AsyncClient(BaseClient):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
model: str = '',
|
||||
input: Union[str, Sequence[AnyStr]] = '',
|
||||
truncate: bool = True,
|
||||
options: Optional[Options] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
if not model:
|
||||
raise RequestError('must provide a model')
|
||||
|
||||
response = await self._request(
|
||||
'POST',
|
||||
'/api/embed',
|
||||
json={
|
||||
'model': model,
|
||||
'input': input,
|
||||
'truncate': truncate,
|
||||
'options': options or {},
|
||||
'keep_alive': keep_alive,
|
||||
},
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str = '',
|
||||
@@ -498,6 +710,22 @@ class AsyncClient(BaseClient):
|
||||
|
||||
return response.json()
|
||||
|
||||
@overload
|
||||
async def pull(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: Literal[False] = False,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
async def pull(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: Literal[True] = True,
|
||||
) -> AsyncIterator[Mapping[str, Any]]: ...
|
||||
|
||||
async def pull(
|
||||
self,
|
||||
model: str,
|
||||
@@ -520,6 +748,22 @@ class AsyncClient(BaseClient):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@overload
|
||||
async def push(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: Literal[False] = False,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
async def push(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: Literal[True] = True,
|
||||
) -> AsyncIterator[Mapping[str, Any]]: ...
|
||||
|
||||
async def push(
|
||||
self,
|
||||
model: str,
|
||||
@@ -542,6 +786,26 @@ class AsyncClient(BaseClient):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@overload
|
||||
async def create(
|
||||
self,
|
||||
model: str,
|
||||
path: Optional[Union[str, PathLike]] = None,
|
||||
modelfile: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
stream: Literal[False] = False,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
async def create(
|
||||
self,
|
||||
model: str,
|
||||
path: Optional[Union[str, PathLike]] = None,
|
||||
modelfile: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
stream: Literal[True] = True,
|
||||
) -> AsyncIterator[Mapping[str, Any]]: ...
|
||||
|
||||
async def create(
|
||||
self,
|
||||
model: str,
|
||||
|
||||
+50
-1
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, TypedDict, Sequence, Literal
|
||||
from typing import Any, TypedDict, Sequence, Literal, Mapping
|
||||
|
||||
import sys
|
||||
|
||||
@@ -53,6 +53,27 @@ class GenerateResponse(BaseGenerateResponse):
|
||||
'Tokenized history up to the point of the response.'
|
||||
|
||||
|
||||
class ToolCallFunction(TypedDict):
|
||||
"""
|
||||
Tool call function.
|
||||
"""
|
||||
|
||||
name: str
|
||||
'Name of the function.'
|
||||
|
||||
args: NotRequired[Mapping[str, Any]]
|
||||
'Arguments of the function.'
|
||||
|
||||
|
||||
class ToolCall(TypedDict):
|
||||
"""
|
||||
Model tool calls.
|
||||
"""
|
||||
|
||||
function: ToolCallFunction
|
||||
'Function to be called.'
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
"""
|
||||
Chat message.
|
||||
@@ -76,6 +97,34 @@ class Message(TypedDict):
|
||||
Valid image formats depend on the model. See the model card for more information.
|
||||
"""
|
||||
|
||||
tool_calls: NotRequired[Sequence[ToolCall]]
|
||||
"""
|
||||
Tools calls to be made by the model.
|
||||
"""
|
||||
|
||||
|
||||
class Property(TypedDict):
|
||||
type: str
|
||||
description: str
|
||||
enum: NotRequired[Sequence[str]] # `enum` is optional and can be a list of strings
|
||||
|
||||
|
||||
class Parameters(TypedDict):
|
||||
type: str
|
||||
required: Sequence[str]
|
||||
properties: Mapping[str, Property]
|
||||
|
||||
|
||||
class ToolFunction(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
parameters: Parameters
|
||||
|
||||
|
||||
class Tool(TypedDict):
|
||||
type: str
|
||||
function: ToolFunction
|
||||
|
||||
|
||||
class ChatResponse(BaseGenerateResponse):
|
||||
"""
|
||||
|
||||
@@ -26,6 +26,7 @@ def test_client_chat(httpserver: HTTPServer):
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
@@ -73,6 +74,7 @@ def test_client_chat_stream(httpserver: HTTPServer):
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'stream': True,
|
||||
'format': '',
|
||||
'options': {},
|
||||
@@ -102,6 +104,7 @@ def test_client_chat_images(httpserver: HTTPServer):
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
},
|
||||
],
|
||||
'tools': [],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
@@ -134,6 +137,7 @@ def test_client_generate(httpserver: HTTPServer):
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'suffix': '',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
@@ -179,6 +183,7 @@ def test_client_generate_stream(httpserver: HTTPServer):
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'suffix': '',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
@@ -207,6 +212,7 @@ def test_client_generate_images(httpserver: HTTPServer):
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'suffix': '',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
@@ -522,6 +528,7 @@ async def test_async_client_chat(httpserver: HTTPServer):
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
@@ -560,6 +567,7 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'tools': [],
|
||||
'stream': True,
|
||||
'format': '',
|
||||
'options': {},
|
||||
@@ -590,6 +598,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
},
|
||||
],
|
||||
'tools': [],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
@@ -613,6 +622,7 @@ async def test_async_client_generate(httpserver: HTTPServer):
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'suffix': '',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
@@ -653,6 +663,7 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'suffix': '',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
@@ -682,6 +693,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'suffix': '',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
|
||||
Reference in New Issue
Block a user