From 76ea1d5d2fa3d7daef2b7503e831cdeb8a4fe0ea Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Fri, 22 May 2026 23:21:11 -0500 Subject: [PATCH] [ROCm][CI] Stabilize Granite tool-use and test URL construction (#43017) Signed-off-by: Andreas Karatzas --- examples/tool_chat_template_granite.jinja | 4 +++ tests/tool_use/test_parallel_tool_calls.py | 35 +++++++++++++++++----- tests/tool_use/test_tool_calls.py | 11 +++++-- tests/tool_use/utils.py | 4 +++ tests/utils.py | 3 +- 5 files changed, 46 insertions(+), 11 deletions(-) diff --git a/examples/tool_chat_template_granite.jinja b/examples/tool_chat_template_granite.jinja index 467dcb2d102..834ec1bec48 100644 --- a/examples/tool_chat_template_granite.jinja +++ b/examples/tool_chat_template_granite.jinja @@ -1,4 +1,8 @@ {%- if tools %} + {%- if messages and messages[0]['role'] != 'system' %} + {{- '<|start_of_role|>system<|end_of_role|>You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user query, respond with <|tool_call|> followed by a JSON list of tools used.<|end_of_text|> +' }} + {%- endif %} {{- '<|start_of_role|>available_tools<|end_of_role|> ' }} {%- for tool in tools %} diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index ed8c80d3667..0f7f6893162 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -13,9 +13,19 @@ from .utils import ( SEED, WEATHER_TOOL, ServerConfig, + ensure_system_prompt, ) +def apply_parallel_tool_system_prompt( + messages, + server_config: ServerConfig, +): + if server_config["model"] == "ibm-granite/granite-3.0-8b-instruct": + return ensure_system_prompt(messages, server_config) + return messages + + # test: getting the model to generate parallel tool calls (streaming/not) # when requested. NOTE that not all models may support this, so some exclusions # may be added in the future. e.g. llama 3.1 models are not designed to support @@ -33,8 +43,11 @@ async def test_parallel_tool_calls( models = await client.models.list() model_name: str = models.data[0].id + messages = apply_parallel_tool_system_prompt( + MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config + ) chat_completion = await client.chat.completions.create( - messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + messages=messages, temperature=0, max_completion_tokens=200, model=model_name, @@ -73,7 +86,7 @@ async def test_parallel_tool_calls( # make the same request, streaming stream = await client.chat.completions.create( model=model_name, - messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + messages=messages, temperature=0, max_completion_tokens=200, tools=[WEATHER_TOOL, SEARCH_TOOL], @@ -162,8 +175,11 @@ async def test_parallel_tool_calls_with_results( models = await client.models.list() model_name: str = models.data[0].id + messages = apply_parallel_tool_system_prompt( + MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, server_config + ) chat_completion = await client.chat.completions.create( - messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + messages=messages, temperature=0, max_completion_tokens=200, model=model_name, @@ -182,7 +198,7 @@ async def test_parallel_tool_calls_with_results( assert "78" in choice.message.content # Orlando temp in tool response stream = await client.chat.completions.create( - messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + messages=messages, temperature=0, max_completion_tokens=200, model=model_name, @@ -220,15 +236,20 @@ async def test_parallel_tool_calls_with_results( @pytest.mark.asyncio -async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI): +async def test_parallel_tool_calls_false( + client: openai.AsyncOpenAI, server_config: ServerConfig +): """ Ensure only one tool call is returned when parallel_tool_calls is False. """ models = await client.models.list() model_name: str = models.data[0].id + messages = apply_parallel_tool_system_prompt( + MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config + ) chat_completion = await client.chat.completions.create( - messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + messages=messages, temperature=0, max_completion_tokens=200, model=model_name, @@ -248,7 +269,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI): # make the same request, streaming stream = await client.chat.completions.create( model=model_name, - messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + messages=messages, temperature=0, max_completion_tokens=200, tools=[WEATHER_TOOL, SEARCH_TOOL], diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index f719a886c89..8d21bcd79cc 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -12,17 +12,22 @@ from .utils import ( SEARCH_TOOL, SEED, WEATHER_TOOL, + ServerConfig, + ensure_system_prompt, ) # test: request a chat completion that should return tool calls, so we know they # are parsable @pytest.mark.asyncio -async def test_tool_call_and_choice(client: openai.AsyncOpenAI): +async def test_tool_call_and_choice( + client: openai.AsyncOpenAI, server_config: ServerConfig +): models = await client.models.list() model_name: str = models.data[0].id + messages = ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config) chat_completion = await client.chat.completions.create( - messages=MESSAGES_ASKING_FOR_TOOLS, + messages=messages, temperature=0, max_completion_tokens=100, model=model_name, @@ -68,7 +73,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): # make the same request, streaming stream = await client.chat.completions.create( model=model_name, - messages=MESSAGES_ASKING_FOR_TOOLS, + messages=messages, temperature=0, max_completion_tokens=100, tools=[WEATHER_TOOL, SEARCH_TOOL], diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 5a03f53ec64..963bc5531c7 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -201,6 +201,10 @@ CONFIGS: dict[str, ServerConfig] = { "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_granite.jinja"), ], + "system_prompt": "You are a helpful AI assistant with access to tools. " + "Use two-letter US state abbreviations in weather tool arguments. " + "When a tool is required to answer the user query, respond with " + "<|tool_call|> followed by a JSON list of tools used.", }, "granite-3.1-8b": { "model": "ibm-granite/granite-3.1-8b-instruct", diff --git a/tests/utils.py b/tests/utils.py index 4b5fb5848f1..fb3bbfd9162 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -660,7 +660,8 @@ class RemoteVLLMServer: ) def url_for(self, *parts: str) -> str: - return self.url_root + "/" + "/".join(parts) + path = "/".join(part.strip("/") for part in parts if part) + return f"{self.url_root}/{path}" def get_client(self, **kwargs): if "timeout" not in kwargs: