feat: add max_tool_calls parameter for automatic tool execution in chat

HamzaYslmn 2026-03-08 20:48:22 +03:00
parent dbccf192ac
commit 5828b8f310
2 changed files with 126 additions and 40 deletions

View File

@@ -53,9 +53,23 @@ available_functions = {
 async def main():
   client = ollama.AsyncClient()
 
+  # --- Auto tool execution (max_tool_calls) ---
+  # When max_tool_calls is set, tools are executed automatically in a loop.
+  # The model calls tools, results are fed back, and the final response is returned.
+  print('\n--- Auto tool execution ---')
   response: ChatResponse = await client.chat(
-    'llama3.1',
+    'qwen3.5:4b',
     messages=messages,
     tools=[add_two_numbers, subtract_two_numbers_tool],
+    max_tool_calls=10,
   )
+  print('Response:', response.message.content)
+
+  # --- Manual tool handling ---
+  # Without max_tool_calls, tool calls are returned for you to handle manually.
+  print('\n--- Manual tool handling ---')
+  response: ChatResponse = await client.chat(
+    'qwen3.5:4b',
+    messages=messages,
+    tools=[add_two_numbers, subtract_two_numbers_tool],
+  )
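The hunk above picks up at line 53 of the example, so the tool definitions and the starting `messages` list it uses are out of frame. A minimal sketch of what they look like (the names come from the diff; the bodies are illustrative, not part of this commit):

# Sketch of the definitions the example relies on (illustrative, not part of this diff).
import ollama
from ollama import ChatResponse

def add_two_numbers(a: int, b: int) -> int:
  """Add two numbers."""
  return a + b

def subtract_two_numbers(a: int, b: int) -> int:
  """Subtract two numbers."""
  return a - b

# Tools can also be passed as plain schema dicts; the example binds one to
# subtract_two_numbers_tool and maps names to callables for manual dispatch.
subtract_two_numbers_tool = {
  'type': 'function',
  'function': {
    'name': 'subtract_two_numbers',
    'description': 'Subtract two numbers',
    'parameters': {
      'type': 'object',
      'required': ['a', 'b'],
      'properties': {
        'a': {'type': 'integer', 'description': 'The first number'},
        'b': {'type': 'integer', 'description': 'The second number'},
      },
    },
  },
}

messages = [{'role': 'user', 'content': 'What is three plus one?'}]
available_functions = {'add_two_numbers': add_two_numbers, 'subtract_two_numbers': subtract_two_numbers}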
@@ -79,7 +93,7 @@ async def main():
     messages.append({'role': 'tool', 'content': str(output), 'tool_name': tool.function.name})
 
     # Get final response from model with function outputs
-    final_response = await client.chat('llama3.1', messages=messages)
+    final_response = await client.chat('qwen3.5:4b', messages=messages)
     print('Final response:', final_response.message.content)
   else:

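Condensed into a standalone script, the new auto-execution path looks roughly like this (model tag and prompt are placeholders, and the sketch assumes a client build that includes this commit):

import asyncio
import ollama

def add_two_numbers(a: int, b: int) -> int:
  """Add two numbers."""
  return a + b

async def main():
  client = ollama.AsyncClient()
  response = await client.chat(
    'qwen3.5:4b',  # any tool-capable model tag
    messages=[{'role': 'user', 'content': 'What is three plus one?'}],
    tools=[add_two_numbers],
    max_tool_calls=10,  # run up to 10 rounds of tool calls automatically
  )
  # The loop has already executed the tool calls and fed the results back,
  # so this is the model's final answer rather than a tool-call request.
  print(response.message.content)

asyncio.run(main())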
View File

@@ -319,6 +319,7 @@ class Client(BaseClient):
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> ChatResponse: ...
 
   @overload
@@ -335,6 +336,7 @@ class Client(BaseClient):
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> Iterator[ChatResponse]: ...
 
   def chat(
@@ -350,6 +352,7 @@ class Client(BaseClient):
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> Union[ChatResponse, Iterator[ChatResponse]]:
     """
     Create a chat response using the requested model.
@@ -361,6 +364,8 @@ class Client(BaseClient):
         For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
 
       stream: Whether to stream the response.
       format: The format of the response.
+      max_tool_calls: If set to a positive int, automatically execute tool calls in a loop
+        up to this many iterations. None (default) disables auto-execution.
 
     Example:
 
       def add_two_numbers(a: int, b: int) -> int:
@@ -376,7 +381,11 @@ class Client(BaseClient):
         '''
         return a + b
 
-      client.chat(model='llama3.2', tools=[add_two_numbers], messages=[...])
+      # Manual tool handling:
+      client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...])
+
+      # Auto tool execution (max 10 iterations):
+      client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...], max_tool_calls=10)
 
     Raises `RequestError` if a model is not provided.
+    Raises `RuntimeError` if auto tool execution exceeds `max_tool_calls` iterations.
@@ -384,24 +393,47 @@ class Client(BaseClient):
     Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
     """
-    return self._request(
-      ChatResponse,
-      'POST',
-      '/api/chat',
-      json=ChatRequest(
-        model=model,
-        messages=list(_copy_messages(messages)),
-        tools=list(_copy_tools(tools)),
-        stream=stream,
-        think=think,
-        logprobs=logprobs,
-        top_logprobs=top_logprobs,
-        format=format,
-        options=options,
-        keep_alive=keep_alive,
-      ).model_dump(exclude_none=True),
-      stream=stream,
-    )
+    # MARK: standard path (no auto tool execution)
+    if stream or not max_tool_calls:
+      return self._request(
+        ChatResponse,
+        'POST',
+        '/api/chat',
+        json=ChatRequest(
+          model=model,
+          messages=list(_copy_messages(messages)),
+          tools=list(_copy_tools(tools)),
+          stream=stream,
+          think=think,
+          logprobs=logprobs,
+          top_logprobs=top_logprobs,
+          format=format,
+          options=options,
+          keep_alive=keep_alive,
+        ).model_dump(exclude_none=True),
+        stream=stream,
+      )
+
+    # MARK: auto tool execution loop
+    tool_map = {f.__name__: f for f in (tools or []) if callable(f)}
+    msgs = list(messages or [])
+    for _ in range(max_tool_calls):
+      response = self._request(
+        ChatResponse,
+        'POST',
+        '/api/chat',
+        json=ChatRequest(
+          model=model,
+          messages=list(_copy_messages(msgs)),
+          tools=list(_copy_tools(tools)),
+          stream=False,
+          think=think,
+          logprobs=logprobs,
+          top_logprobs=top_logprobs,
+          format=format,
+          options=options,
+          keep_alive=keep_alive,
+        ).model_dump(exclude_none=True),
+        stream=False,
+      )
+      if not response.message.tool_calls:
+        return response
+      msgs.append(response.message)
+      for tc in response.message.tool_calls:
+        output = _exec_tool(tool_map, tc)
+        msgs.append({'role': 'tool', 'content': output, 'tool_name': tc.function.name})
+    raise RuntimeError(f'Tool calling exceeded {max_tool_calls} iterations')
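Because the loop raises instead of returning the partial conversation once the cap is hit, callers that prefer a soft failure can catch the `RuntimeError`. A sketch (model tag and prompt are placeholders):

import ollama

def add_two_numbers(a: int, b: int) -> int:
  """Add two numbers."""
  return a + b

client = ollama.Client()
try:
  response = client.chat(
    'qwen3.5:4b',
    messages=[{'role': 'user', 'content': 'Add 3 and 1, then add 2 to the result.'}],
    tools=[add_two_numbers],
    max_tool_calls=3,  # deliberately low cap
  )
  print(response.message.content)
except RuntimeError as exc:
  # Raised by the loop above when the model keeps requesting tools past the cap.
  print('Gave up after too many tool rounds:', exc)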
 
   def embed(
     self,
@@ -951,6 +983,7 @@ class AsyncClient(BaseClient):
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> ChatResponse: ...
 
   @overload
@@ -967,6 +1000,7 @@ class AsyncClient(BaseClient):
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> AsyncIterator[ChatResponse]: ...
 
   async def chat(
@@ -982,6 +1016,7 @@ class AsyncClient(BaseClient):
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    max_tool_calls: Optional[int] = None,
   ) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
     """
     Create a chat response using the requested model.
@@ -993,6 +1028,8 @@ class AsyncClient(BaseClient):
         For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
 
       stream: Whether to stream the response.
       format: The format of the response.
+      max_tool_calls: If set to a positive int, automatically execute tool calls in a loop
+        up to this many iterations. None (default) disables auto-execution.
 
     Example:
 
       def add_two_numbers(a: int, b: int) -> int:
@@ -1008,7 +1045,11 @@ class AsyncClient(BaseClient):
         '''
         return a + b
 
-      await client.chat(model='llama3.2', tools=[add_two_numbers], messages=[...])
+      # Manual tool handling:
+      await client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...])
+
+      # Auto tool execution (max 10 iterations):
+      await client.chat(model='qwen3.5:4b', tools=[add_two_numbers], messages=[...], max_tool_calls=10)
 
     Raises `RequestError` if a model is not provided.
+    Raises `RuntimeError` if auto tool execution exceeds `max_tool_calls` iterations.
@@ -1016,25 +1057,47 @@ class AsyncClient(BaseClient):
     Returns `ChatResponse` if `stream` is `False`, otherwise returns an asynchronous `ChatResponse` generator.
     """
-    return await self._request(
-      ChatResponse,
-      'POST',
-      '/api/chat',
-      json=ChatRequest(
-        model=model,
-        messages=list(_copy_messages(messages)),
-        tools=list(_copy_tools(tools)),
-        stream=stream,
-        think=think,
-        logprobs=logprobs,
-        top_logprobs=top_logprobs,
-        format=format,
-        options=options,
-        keep_alive=keep_alive,
-      ).model_dump(exclude_none=True),
-      stream=stream,
-    )
+    # MARK: standard path (no auto tool execution)
+    if stream or not max_tool_calls:
+      return await self._request(
+        ChatResponse,
+        'POST',
+        '/api/chat',
+        json=ChatRequest(
+          model=model,
+          messages=list(_copy_messages(messages)),
+          tools=list(_copy_tools(tools)),
+          stream=stream,
+          think=think,
+          logprobs=logprobs,
+          top_logprobs=top_logprobs,
+          format=format,
+          options=options,
+          keep_alive=keep_alive,
+        ).model_dump(exclude_none=True),
+        stream=stream,
+      )
+
+    # MARK: auto tool execution loop
+    tool_map = {f.__name__: f for f in (tools or []) if callable(f)}
+    msgs = list(messages or [])
+    for _ in range(max_tool_calls):
+      response = await self._request(
+        ChatResponse,
+        'POST',
+        '/api/chat',
+        json=ChatRequest(
+          model=model,
+          messages=list(_copy_messages(msgs)),
+          tools=list(_copy_tools(tools)),
+          stream=False,
+          think=think,
+          logprobs=logprobs,
+          top_logprobs=top_logprobs,
+          format=format,
+          options=options,
+          keep_alive=keep_alive,
+        ).model_dump(exclude_none=True),
+        stream=False,
+      )
+      if not response.message.tool_calls:
+        return response
+      msgs.append(response.message)
+      for tc in response.message.tool_calls:
+        output = _exec_tool(tool_map, tc)
+        msgs.append({'role': 'tool', 'content': output, 'tool_name': tc.function.name})
+    raise RuntimeError(f'Tool calling exceeded {max_tool_calls} iterations')
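One subtlety in the guard `if stream or not max_tool_calls`: it is true for `max_tool_calls=None` and `max_tool_calls=0` alike, and for any streaming request, so all three take the standard path and never enter the loop. A quick illustration (hypothetical helper, not library code):

# Mirrors the dispatch in chat(); written out only to make the cases explicit.
def takes_auto_path(stream: bool, max_tool_calls) -> bool:
  return not (stream or not max_tool_calls)

assert takes_auto_path(False, 10) is True     # auto tool execution loop
assert takes_auto_path(False, None) is False  # manual handling (the default)
assert takes_auto_path(False, 0) is False     # 0 also disables auto-execution
assert takes_auto_path(True, 10) is False     # streaming always bypasses the loop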
 
   async def embed(
     self,
@@ -1330,6 +1393,15 @@ def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable
     yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)
 
 
+def _exec_tool(tool_map: dict, tc: Message.ToolCall) -> str:
+  """Execute a tool call, return result as string."""
+  fn = tool_map.get(tc.function.name)
+  if not fn:
+    return json.dumps({'error': f'Tool {tc.function.name} not found'})
+  output = fn(**tc.function.arguments)
+  return output if isinstance(output, str) else json.dumps(output, default=str)
+
+
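As written, `_exec_tool` lets exceptions raised by the tool itself propagate and abort the auto-execution loop. A defensive variant that reports the failure back to the model instead might look like this (a sketch, not part of the commit):

import json

def _exec_tool_safe(tool_map: dict, tc) -> str:
  """Like _exec_tool, but feeds tool errors back as output instead of raising."""
  fn = tool_map.get(tc.function.name)
  if not fn:
    return json.dumps({'error': f'Tool {tc.function.name} not found'})
  try:
    output = fn(**tc.function.arguments)
  except Exception as e:  # surface the failure to the model and keep looping
    return json.dumps({'error': f'{type(e).__name__}: {e}'})
  return output if isinstance(output, str) else json.dumps(output, default=str)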
 def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
   if isinstance(s, (str, Path)):
     try: