mirror of
https://github.com/ollama/ollama-python.git
synced 2026-05-01 11:48:17 +08:00
feat: add max_tool_calls parameter for automatic tool execution in chat
This commit is contained in:
parent
dbccf192ac
commit
5828b8f310
@ -53,9 +53,23 @@ available_functions = {
|
||||
|
||||
async def main():
|
||||
client = ollama.AsyncClient()
|
||||
|
||||
# --- Auto tool execution (max_tool_calls) ---
|
||||
# When max_tool_calls is set, tools are executed automatically in a loop.
|
||||
# The model calls tools, results are fed back, and the final response is returned.
|
||||
print('\n--- Auto tool execution ---')
|
||||
response: ChatResponse = await client.chat(
|
||||
'llama3.1',
|
||||
'qwen3.5:4b',
|
||||
messages=messages,
|
||||
tools=[add_two_numbers, subtract_two_numbers_tool],
|
||||
max_tool_calls=10
|
||||
)
|
||||
print('Response:', response.message.content)
|
||||
|
||||
# --- Manual tool handling ---
|
||||
# Without max_tool_calls, tool calls are returned for you to handle manually.
|
||||
print('\n--- Manual tool handling ---')
|
||||
response: ChatResponse = await client.chat(
|
||||
'qwen3.5:4b',
|
||||
messages=messages,
|
||||
tools=[add_two_numbers, subtract_two_numbers_tool],
|
||||
)
|
||||
@ -79,7 +93,7 @@ async def main():
|
||||
messages.append({'role': 'tool', 'content': str(output), 'tool_name': tool.function.name})
|
||||
|
||||
# Get final response from model with function outputs
|
||||
final_response = await client.chat('llama3.1', messages=messages)
|
||||
final_response = await client.chat('qwen3.5:4b', messages=messages)
|
||||
print('Final response:', final_response.message.content)
|
||||
|
||||
else:
|
||||
|
||||
@ -319,6 +319,7 @@ class Client(BaseClient):
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
max_tool_calls: Optional[int] = None,
|
||||
) -> ChatResponse: ...
|
||||
|
||||
@overload
|
||||
@ -335,6 +336,7 @@ class Client(BaseClient):
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
max_tool_calls: Optional[int] = None,
|
||||
) -> Iterator[ChatResponse]: ...
|
||||
|
||||
def chat(
|
||||
@ -350,6 +352,7 @@ class Client(BaseClient):
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
max_tool_calls: Optional[int] = None,
|
||||
) -> Union[ChatResponse, Iterator[ChatResponse]]:
|
||||
"""
|
||||
Create a chat response using the requested model.
|
||||
@ -361,6 +364,8 @@ class Client(BaseClient):
|
||||
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
|
||||
stream: Whether to stream the response.
|
||||
format: The format of the response.
|
||||
max_tool_calls: If set to a positive int, automatically execute tool calls in a loop
|
||||
up to this many iterations. None (default) disables auto-execution.
|
||||
|
||||
Example:
|
||||
def add_two_numbers(a: int, b: int) -> int:
|
||||
@ -376,7 +381,11 @@ class Client(BaseClient):
|
||||
'''
|
||||
return a + b
|
||||
|
||||
client.chat(model='llama3.2', tools=[add_two_numbers], messages=[...])
|
||||
# Manual tool handling:
|
||||
client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...])
|
||||
|
||||
# Auto tool execution (max 10 iterations):
|
||||
client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...], max_tool_calls=10)
|
||||
|
||||
Raises `RequestError` if a model is not provided.
|
||||
|
||||
@ -384,24 +393,47 @@ class Client(BaseClient):
|
||||
|
||||
Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
|
||||
"""
|
||||
return self._request(
|
||||
ChatResponse,
|
||||
'POST',
|
||||
'/api/chat',
|
||||
json=ChatRequest(
|
||||
model=model,
|
||||
messages=list(_copy_messages(messages)),
|
||||
tools=list(_copy_tools(tools)),
|
||||
# MARK: standard path (no auto tool execution)
|
||||
if stream or not max_tool_calls:
|
||||
return self._request(
|
||||
ChatResponse,
|
||||
'POST',
|
||||
'/api/chat',
|
||||
json=ChatRequest(
|
||||
model=model,
|
||||
messages=list(_copy_messages(messages)),
|
||||
tools=list(_copy_tools(tools)),
|
||||
stream=stream,
|
||||
think=think,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
format=format,
|
||||
options=options,
|
||||
keep_alive=keep_alive,
|
||||
).model_dump(exclude_none=True),
|
||||
stream=stream,
|
||||
think=think,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
format=format,
|
||||
options=options,
|
||||
keep_alive=keep_alive,
|
||||
).model_dump(exclude_none=True),
|
||||
stream=stream,
|
||||
)
|
||||
)
|
||||
|
||||
# MARK: auto tool execution loop
|
||||
tool_map = {f.__name__: f for f in (tools or []) if callable(f)}
|
||||
msgs = list(messages or [])
|
||||
|
||||
for _ in range(max_tool_calls):
|
||||
response = self._request(
|
||||
ChatResponse, 'POST', '/api/chat',
|
||||
json=ChatRequest(
|
||||
model=model, messages=list(_copy_messages(msgs)), tools=list(_copy_tools(tools)),
|
||||
stream=False, think=think, format=format, options=options, keep_alive=keep_alive,
|
||||
).model_dump(exclude_none=True), stream=False,
|
||||
)
|
||||
if not response.message.tool_calls:
|
||||
return response
|
||||
msgs.append(response.message)
|
||||
for tc in response.message.tool_calls:
|
||||
output = _exec_tool(tool_map, tc)
|
||||
msgs.append({'role': 'tool', 'content': output, 'tool_name': tc.function.name})
|
||||
|
||||
raise RuntimeError(f'Tool calling exceeded {max_tool_calls} iterations')
|
||||
|
||||
def embed(
|
||||
self,
|
||||
@ -951,6 +983,7 @@ class AsyncClient(BaseClient):
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
max_tool_calls: Optional[int] = None,
|
||||
) -> ChatResponse: ...
|
||||
|
||||
@overload
|
||||
@ -967,6 +1000,7 @@ class AsyncClient(BaseClient):
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
max_tool_calls: Optional[int] = None,
|
||||
) -> AsyncIterator[ChatResponse]: ...
|
||||
|
||||
async def chat(
|
||||
@ -982,6 +1016,7 @@ class AsyncClient(BaseClient):
|
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
|
||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||
keep_alive: Optional[Union[float, str]] = None,
|
||||
max_tool_calls: Optional[int] = None,
|
||||
) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
|
||||
"""
|
||||
Create a chat response using the requested model.
|
||||
@ -993,6 +1028,8 @@ class AsyncClient(BaseClient):
|
||||
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
|
||||
stream: Whether to stream the response.
|
||||
format: The format of the response.
|
||||
max_tool_calls: If set to a positive int, automatically execute tool calls in a loop
|
||||
up to this many iterations. None (default) disables auto-execution.
|
||||
|
||||
Example:
|
||||
def add_two_numbers(a: int, b: int) -> int:
|
||||
@ -1008,7 +1045,11 @@ class AsyncClient(BaseClient):
|
||||
'''
|
||||
return a + b
|
||||
|
||||
await client.chat(model='llama3.2', tools=[add_two_numbers], messages=[...])
|
||||
# Manual tool handling:
|
||||
await client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...])
|
||||
|
||||
# Auto tool execution (max 10 iterations):
|
||||
await client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...], max_tool_calls=10)
|
||||
|
||||
Raises `RequestError` if a model is not provided.
|
||||
|
||||
@ -1016,25 +1057,47 @@ class AsyncClient(BaseClient):
|
||||
|
||||
Returns `ChatResponse` if `stream` is `False`, otherwise returns an asynchronous `ChatResponse` generator.
|
||||
"""
|
||||
|
||||
return await self._request(
|
||||
ChatResponse,
|
||||
'POST',
|
||||
'/api/chat',
|
||||
json=ChatRequest(
|
||||
model=model,
|
||||
messages=list(_copy_messages(messages)),
|
||||
tools=list(_copy_tools(tools)),
|
||||
# MARK: standard path (no auto tool execution)
|
||||
if stream or not max_tool_calls:
|
||||
return await self._request(
|
||||
ChatResponse,
|
||||
'POST',
|
||||
'/api/chat',
|
||||
json=ChatRequest(
|
||||
model=model,
|
||||
messages=list(_copy_messages(messages)),
|
||||
tools=list(_copy_tools(tools)),
|
||||
stream=stream,
|
||||
think=think,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
format=format,
|
||||
options=options,
|
||||
keep_alive=keep_alive,
|
||||
).model_dump(exclude_none=True),
|
||||
stream=stream,
|
||||
think=think,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
format=format,
|
||||
options=options,
|
||||
keep_alive=keep_alive,
|
||||
).model_dump(exclude_none=True),
|
||||
stream=stream,
|
||||
)
|
||||
)
|
||||
|
||||
# MARK: auto tool execution loop
|
||||
tool_map = {f.__name__: f for f in (tools or []) if callable(f)}
|
||||
msgs = list(messages or [])
|
||||
|
||||
for _ in range(max_tool_calls):
|
||||
response = await self._request(
|
||||
ChatResponse, 'POST', '/api/chat',
|
||||
json=ChatRequest(
|
||||
model=model, messages=list(_copy_messages(msgs)), tools=list(_copy_tools(tools)),
|
||||
stream=False, think=think, format=format, options=options, keep_alive=keep_alive,
|
||||
).model_dump(exclude_none=True), stream=False,
|
||||
)
|
||||
if not response.message.tool_calls:
|
||||
return response
|
||||
msgs.append(response.message)
|
||||
for tc in response.message.tool_calls:
|
||||
output = _exec_tool(tool_map, tc)
|
||||
msgs.append({'role': 'tool', 'content': output, 'tool_name': tc.function.name})
|
||||
|
||||
raise RuntimeError(f'Tool calling exceeded {max_tool_calls} iterations')
|
||||
|
||||
async def embed(
|
||||
self,
|
||||
@ -1330,6 +1393,15 @@ def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable
|
||||
yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)
|
||||
|
||||
|
||||
def _exec_tool(tool_map: dict, tc: Message.ToolCall) -> str:
|
||||
"""Execute a tool call, return result as string."""
|
||||
fn = tool_map.get(tc.function.name)
|
||||
if not fn:
|
||||
return json.dumps({'error': f'Tool {tc.function.name} not found'})
|
||||
output = fn(**tc.function.arguments)
|
||||
return output if isinstance(output, str) else json.dumps(output, default=str)
|
||||
|
||||
|
||||
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
|
||||
if isinstance(s, (str, Path)):
|
||||
try:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user